From 67f6a90fba62b7af9653d4c185039d84573d235a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 19 Jun 2025 06:58:58 +0000 Subject: [PATCH 01/57] refactor: update dj package to use new name --- src/orcapod/dj/mapper.py | 6 +++--- src/orcapod/dj/operation.py | 4 ++-- src/orcapod/dj/source.py | 12 ++++++------ src/orcapod/dj/stream.py | 2 +- src/orcapod/dj/tracker.py | 12 ++++++------ src/orcapod/pod/core.py | 4 ++++ 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/orcapod/dj/mapper.py b/src/orcapod/dj/mapper.py index a38fdaf..efec07c 100644 --- a/src/orcapod/dj/mapper.py +++ b/src/orcapod/dj/mapper.py @@ -1,18 +1,18 @@ import warnings from typing import Optional -from orcapod.mappers import Join, MapPackets, Mapper, MapTags +from orcapod.core.operators import Join, MapPackets, MapTags, Operator from .operation import QueryOperation from .stream import QueryStream -class QueryMapper(QueryOperation, Mapper): +class QueryMapper(QueryOperation, Operator): """ A special type of mapper that returns and works with QueryStreams """ -def convert_to_query_mapper(operation: Mapper) -> QueryMapper: +def convert_to_query_mapper(operation: Operator) -> QueryMapper: """ Convert a generic mapper to an equivalent, Query mapper """ diff --git a/src/orcapod/dj/operation.py b/src/orcapod/dj/operation.py index d4d5a81..70b218e 100644 --- a/src/orcapod/dj/operation.py +++ b/src/orcapod/dj/operation.py @@ -1,8 +1,8 @@ -from ..base import Operation +from orcapod.core.base import Kernel from .stream import QueryStream -class QueryOperation(Operation): +class QueryOperation(Kernel): """ A special type of operation that returns and works with QueryStreams diff --git a/src/orcapod/dj/source.py b/src/orcapod/dj/source.py index 8af3f23..0eaa6dc 100644 --- a/src/orcapod/dj/source.py +++ b/src/orcapod/dj/source.py @@ -6,12 +6,12 @@ from orcapod.hashing import hash_to_uuid -from orcapod.sources import Source -from orcapod.streams import SyncStream -from ..utils.name import pascal_to_snake, snake_to_pascal -from ..utils.stream_utils import common_elements -from .operation import QueryOperation -from .stream import QueryStream, TableCachedStream, TableStream +from orcapod.core.sources import Source +from orcapod.core.streams import SyncStream +from orcapod.utils.name import pascal_to_snake, snake_to_pascal +from orcapod.utils.stream_utils import common_elements +from orcapod.dj.operation import QueryOperation +from orcapod.dj.stream import QueryStream, TableCachedStream, TableStream logger = logging.getLogger(__name__) diff --git a/src/orcapod/dj/stream.py b/src/orcapod/dj/stream.py index 3e4eb08..e8e7195 100644 --- a/src/orcapod/dj/stream.py +++ b/src/orcapod/dj/stream.py @@ -5,7 +5,7 @@ from datajoint.expression import QueryExpression from datajoint.table import Table -from orcapod.streams import SyncStream +from orcapod.core.streams import SyncStream logger = logging.getLogger(__name__) diff --git a/src/orcapod/dj/tracker.py b/src/orcapod/dj/tracker.py index b137e54..24df900 100644 --- a/src/orcapod/dj/tracker.py +++ b/src/orcapod/dj/tracker.py @@ -6,10 +6,10 @@ import networkx as nx from datajoint import Schema -from orcapod.base import Operation, Source -from orcapod.mappers import Mapper, Merge +from orcapod.core.base import Kernel, Source +from orcapod.core.operators import Operator, Merge from orcapod.pod import FunctionPod -from orcapod.pipeline import GraphTracker +from orcapod.core.tracker import GraphTracker from .mapper import convert_to_query_mapper from .operation import 
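For orientation, the old-to-new name mapping that patch 01 applies across the dj package, collected from the hunks above (a summary of the patch, not additional code in it):

# Old import / attribute (pre-refactor)          ->  New form (this patch)
# from orcapod.base import Operation             ->  from orcapod.core.base import Kernel
# from orcapod.mappers import Mapper, Join, ...  ->  from orcapod.core.operators import Operator, Join, ...
# from orcapod.sources import Source             ->  from orcapod.core.sources import Source
# from orcapod.streams import SyncStream         ->  from orcapod.core.streams import SyncStream
# from orcapod.pipeline import GraphTracker      ->  from orcapod.core.tracker import GraphTracker
# invocation.operation                           ->  invocation.kernel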
QueryOperation @@ -19,7 +19,7 @@ def convert_to_query_operation( - operation: Operation, + operation: Kernel, schema: Schema, table_name: str = None, table_postfix: str = "", @@ -68,7 +68,7 @@ def convert_to_query_operation( True, ) - if isinstance(operation, Mapper): + if isinstance(operation, Operator): return convert_to_query_mapper(operation), True # operation conversion is not supported, raise an error @@ -102,7 +102,7 @@ def generate_tables( for invocation in nx.topological_sort(G): streams = [edge_lut.get(stream, stream) for stream in invocation.streams] new_node, converted = convert_to_query_operation( - invocation.operation, + invocation.kernel, schema, table_name=None, table_postfix=invocation.content_hash_int(), diff --git a/src/orcapod/pod/core.py b/src/orcapod/pod/core.py index a82944f..5c3ee67 100644 --- a/src/orcapod/pod/core.py +++ b/src/orcapod/pod/core.py @@ -303,6 +303,10 @@ def generator() -> Iterator[tuple[Tag, Packet]]: elif self.error_handling == "warn": warnings.warn(f"Error processing packet {packet}: {e}") continue + else: + raise ValueError( + f"Unknown error handling mode: {self.error_handling} encountered while handling error:" + ) from e output_packet: Packet = { k: v for k, v in zip(self.output_keys, output_values) From ac228b02e22b64b5e17d095f6d78b58bcea3cc55 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 19 Jun 2025 20:48:35 +0000 Subject: [PATCH 02/57] feat: add ability to skip computation in pod --- src/orcapod/pod/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/orcapod/pod/core.py b/src/orcapod/pod/core.py index 5c3ee67..069e0de 100644 --- a/src/orcapod/pod/core.py +++ b/src/orcapod/pod/core.py @@ -186,6 +186,7 @@ def __init__( function_hash_mode: Literal["signature", "content", "name", "custom"] = "name", custom_hash: int | None = None, label: str | None = None, + skip_computation: bool = False, force_computation: bool = False, skip_memoization_lookup: bool = False, skip_memoization: bool = False, @@ -209,6 +210,7 @@ def __init__( self.store_name = store_name or function_name self.function_hash_mode = function_hash_mode self.custom_hash = custom_hash + self.skip_computation = skip_computation self.force_computation = force_computation self.skip_memoization_lookup = skip_memoization_lookup self.skip_memoization = skip_memoization @@ -277,6 +279,9 @@ def generator() -> Iterator[tuple[Tag, Packet]]: logger.info("Memoized packet found, skipping computation") yield tag, memoized_packet continue + if self.skip_computation: + logger.info("Skipping computation as per configuration") + continue values = self.function(**packet) if len(self.output_keys) == 0: From 450ec90aa8e0b8e33eac9433635bde99dcc5b763 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 24 Jun 2025 20:18:09 +0000 Subject: [PATCH 03/57] refactor: major change of structure and implementation of pipeline --- pyproject.toml | 2 +- src/orcapod/__init__.py | 4 +- src/orcapod/core/__init__.py | 13 + src/orcapod/core/base.py | 58 +- src/orcapod/core/operators.py | 48 +- src/orcapod/core/pod.py | 311 +++++++ src/orcapod/core/pod_legacy.py | 373 ++++++++ src/orcapod/core/tracker.py | 48 +- src/orcapod/dj/pod.py | 2 +- src/orcapod/dj/tracker.py | 2 +- src/orcapod/hashing/core.py | 35 +- src/orcapod/hashing/hashing_legacy.py | 269 ------ src/orcapod/hashing/types.py | 4 +- src/orcapod/pipeline/pipeline.py | 706 ++------------ src/orcapod/pipeline/wrappers.py | 667 +++++++++++++ src/orcapod/pod/__init__.py | 9 - src/orcapod/pod/core.py | 877 ------------------ src/orcapod/store/__init__.py | 5 +- src/orcapod/store/arrow_data_stores.py | 587 ++++++++++-- .../store/{core.py => dict_data_stores.py} | 0 src/orcapod/store/file.py | 159 ---- src/orcapod/store/file_ops.py | 158 +++- src/orcapod/store/optimized_memory_store.py | 433 +++++++++ .../{transfer.py => transfer_data_store.py} | 0 src/orcapod/store/types.py | 1 + src/orcapod/types/__init__.py | 79 +- src/orcapod/types/core.py | 49 +- src/orcapod/types/default.py | 18 - src/orcapod/types/registry.py | 10 +- .../types/{inference.py => typespec.py} | 9 +- uv.lock | 18 +- 31 files changed, 2773 insertions(+), 2181 deletions(-) create mode 100644 src/orcapod/core/pod.py create mode 100644 src/orcapod/core/pod_legacy.py delete mode 100644 src/orcapod/hashing/hashing_legacy.py create mode 100644 src/orcapod/pipeline/wrappers.py delete mode 100644 src/orcapod/pod/__init__.py delete mode 100644 src/orcapod/pod/core.py rename src/orcapod/store/{core.py => dict_data_stores.py} (100%) delete mode 100644 src/orcapod/store/file.py create mode 100644 src/orcapod/store/optimized_memory_store.py rename src/orcapod/store/{transfer.py => transfer_data_store.py} (100%) delete mode 100644 src/orcapod/types/default.py rename src/orcapod/types/{inference.py => typespec.py} (98%) diff --git a/pyproject.toml b/pyproject.toml index ca1c20c..aa23332 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pandas>=2.2.3", "pyyaml>=6.0.2", "pyarrow>=20.0.0", - "polars>=1.30.0", + "polars>=1.31.0", "beartype>=0.21.0", ] readme = "README.md" diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index a84492a..db457e9 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -1,8 +1,8 @@ from .core import operators, sources, streams from .core.streams import SyncStreamFromLists, SyncStreamFromGenerator -from . import hashing, pod, store +from . 
import hashing, store from .core.operators import Join, MapPackets, MapTags, packet, tag -from .pod import FunctionPod, function_pod +from .core.pod import FunctionPod, function_pod from .core.sources import GlobSource from .store import DirDataStore, SafeDirDataStore from .core.tracker import GraphTracker diff --git a/src/orcapod/core/__init__.py b/src/orcapod/core/__init__.py index e69de29..d236681 100644 --- a/src/orcapod/core/__init__.py +++ b/src/orcapod/core/__init__.py @@ -0,0 +1,13 @@ +from .base import Kernel, Invocation, Stream, SyncStream, Source +from .operators import Operator +from .pod import Pod + +__all__ = [ + "Kernel", + "Operator", + "Invocation", + "Stream", + "SyncStream", + "Source", + "Pod", +] diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 0b1ed63..664352d 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterator -from typing import Any, TypeVar, Hashable +from typing import Any from orcapod.hashing import HashableMixin @@ -27,9 +27,10 @@ class Kernel(ABC, HashableMixin): for computational graph tracking. """ - def __init__(self, label: str | None = None, **kwargs) -> None: + def __init__(self, label: str | None = None, skip_tracking: bool = False, **kwargs) -> None: super().__init__(**kwargs) self._label = label + self._skip_tracking = skip_tracking @property def label(self) -> str: @@ -40,29 +41,51 @@ def label(self) -> str: if self._label: return self._label return self.__class__.__name__ - + @label.setter def label(self, label: str) -> None: self._label = label + def pre_forward_hook( + self, *streams: "SyncStream", **kwargs + ) -> tuple["SyncStream", ...]: + """ + A hook that is called before the forward method is invoked. + This can be used to perform any pre-processing or validation on the input streams. + Subclasses can override this method to provide custom behavior. + """ + return streams + + def post_forward_hook(self, output_stream: "SyncStream", **kwargs) -> "SyncStream": + """ + A hook that is called after the forward method is invoked. + This can be used to perform any post-processing on the output stream. + Subclasses can override this method to provide custom behavior. 
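A minimal sketch of how a subclass might use these hooks, assuming only the Kernel and SyncStream interfaces introduced in this patch; the validating kernel below is illustrative, not part of orcapod:

from orcapod.core.base import Kernel, SyncStream


class SingleInputKernel(Kernel):
    """Hypothetical kernel that validates its inputs before forward() runs."""

    def pre_forward_hook(self, *streams: SyncStream, **kwargs) -> tuple[SyncStream, ...]:
        if len(streams) != 1:
            raise ValueError(f"{self.label} expects exactly one input stream, got {len(streams)}")
        return streams

    def forward(self, *streams: SyncStream) -> SyncStream:
        # pass-through; Kernel.__call__ wraps this with the pre/post hooks above
        return streams[0]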
+ """ + return output_stream + + def __call__(self, *streams: "SyncStream", **kwargs) -> "SyncStream": # Special handling of Source: trigger call on source if passed as stream normalized_streams = [ stream() if isinstance(stream, Source) else stream for stream in streams ] - output_stream = self.forward(*normalized_streams, **kwargs) + pre_processed_streams = self.pre_forward_hook(*normalized_streams, **kwargs) + output_stream = self.forward(*pre_processed_streams, **kwargs) + post_processed_stream = self.post_forward_hook(output_stream, **kwargs) # create an invocation instance - invocation = Invocation(self, normalized_streams) + invocation = Invocation(self, pre_processed_streams) # label the output_stream with the invocation that produced the stream - output_stream.invocation = invocation + post_processed_stream.invocation = invocation - # register the invocation to all active trackers - active_trackers = Tracker.get_active_trackers() - for tracker in active_trackers: - tracker.record(invocation) + if not self._skip_tracking: + # register the invocation to all active trackers + active_trackers = Tracker.get_active_trackers() + for tracker in active_trackers: + tracker.record(invocation) - return output_stream + return post_processed_stream @abstractmethod def forward(self, *streams: "SyncStream") -> "SyncStream": @@ -98,7 +121,7 @@ def identity_structure(self, *streams: "SyncStream") -> Any: logger.warning( f"Identity structure not implemented for {self.__class__.__name__}" ) - return (self.__class__.__name__,) + tuple(streams) + return (self.__class__.__name__,) + streams def keys( self, *streams: "SyncStream", trigger_run: bool = False @@ -365,6 +388,8 @@ def keys( tag_keys, packet_keys = self.invocation.keys() if tag_keys is not None and packet_keys is not None: return tag_keys, packet_keys + if not trigger_run: + return None, None # otherwise, use the keys from the first packet in the stream # note that this may be computationally expensive tag, packet = next(iter(self)) @@ -386,6 +411,8 @@ def types(self, *, trigger_run=False) -> tuple[TypeSpec | None, TypeSpec | None] tag_types, packet_types = self.invocation.types() if not trigger_run or (tag_types is not None and packet_types is not None): return tag_types, packet_types + if not trigger_run: + return None, None # otherwise, use the keys from the first packet in the stream # note that this may be computationally expensive tag, packet = next(iter(self)) @@ -488,13 +515,6 @@ def claims_unique_tags(self, *, trigger_run=False) -> bool | None: return True -class Operator(Kernel): - """ - A Mapper is an operation that does NOT generate new file content. - It is used to control the flow of data in the pipeline without modifying or creating data content. - """ - - class Source(Kernel, SyncStream): """ A base class for all sources in the system. 
A source can be seen as a special diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index 093167b..84a31f3 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -3,9 +3,9 @@ from itertools import chain from typing import Any - -from orcapod.core.base import Operator, SyncStream +from orcapod.types import Packet, Tag, TypeSpec from orcapod.hashing import function_content_hash, hash_function +from orcapod.core.base import Kernel, SyncStream from orcapod.core.streams import SyncStreamFromGenerator from orcapod.utils.stream_utils import ( batch_packet, @@ -16,7 +16,12 @@ merge_typespecs, ) -from orcapod.types import Packet, Tag, TypeSpec + +class Operator(Kernel): + """ + A Mapper is an operation that does NOT generate new file content. + It is used to control the flow of data in the pipeline without modifying or creating data content. + """ class Repeat(Operator): @@ -186,12 +191,43 @@ def claims_unique_tags( return True +def union_lists(left, right): + if left is None or right is None: + return None + output = list(left) + for item in right: + if item not in output: + output.append(item) + return output + class Join(Operator): def identity_structure(self, *streams): # Join does not depend on the order of the streams -- convert it onto a set return (self.__class__.__name__, set(streams)) + def keys( + self, *streams: SyncStream, trigger_run=False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the types of the operation. + The first list contains the keys of the tags, and the second list contains the keys of the packets. + The keys are returned if it is feasible to do so, otherwise a tuple + (None, None) is returned to signify that the keys are not known. + """ + if len(streams) != 2: + raise ValueError("Join operation requires exactly two streams") + + left_stream, right_stream = streams + left_tag_keys, left_packet_keys = left_stream.keys(trigger_run=trigger_run) + right_tag_keys, right_packet_keys = right_stream.keys(trigger_run=trigger_run) + + # TODO: do error handling when merge fails + joined_tag_keys = union_lists(left_tag_keys, right_tag_keys) + joined_packet_keys = union_lists(left_packet_keys, right_packet_keys) + + return joined_tag_keys, joined_packet_keys + def types( self, *streams: SyncStream, trigger_run=False ) -> tuple[TypeSpec | None, TypeSpec | None]: @@ -225,8 +261,10 @@ def forward(self, *streams: SyncStream) -> SyncStream: left_stream, right_stream = streams def generator() -> Iterator[tuple[Tag, Packet]]: - for left_tag, left_packet in left_stream: - for right_tag, right_packet in right_stream: + left_stream_buffered = list(left_stream) + right_stream_buffered = list(right_stream) + for left_tag, left_packet in left_stream_buffered: + for right_tag, right_packet in right_stream_buffered: if (joined_tag := join_tags(left_tag, right_tag)) is not None: if not check_packet_compatibility(left_packet, right_packet): raise ValueError( diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py new file mode 100644 index 0000000..582fa85 --- /dev/null +++ b/src/orcapod/core/pod.py @@ -0,0 +1,311 @@ +import logging +import warnings +import sys +from collections.abc import Callable, Collection, Iterable, Iterator, Sequence +from typing import ( + Any, + Literal, +) + +from orcapod.types import Packet, Tag, TypeSpec, default_registry +from orcapod.types.typespec import extract_function_typespecs +from orcapod.types.registry import PacketConverter + +from orcapod.hashing import ( + 
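For reference, the key-merging helper added to operators.py above behaves as an order-preserving union, with None meaning "keys unknown" and propagating (a small illustration, not part of the patch):

# assuming union_lists as defined in operators.py above
assert union_lists(["a", "b"], ["b", "c"]) == ["a", "b", "c"]  # order kept, duplicates dropped
assert union_lists(None, ["x"]) is None                        # unknown keys stay unknown
assert union_lists(["x"], None) is None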
FunctionInfoExtractor, + get_function_signature, +) +from orcapod.core import Kernel +from orcapod.core.operators import Join +from orcapod.core.streams import ( + SyncStream, + SyncStreamFromGenerator, +) + +logger = logging.getLogger(__name__) + + +class Pod(Kernel): + """ + An (abstract) base class for all pods. A pod can be seen as a special type of operation that + only operates on the packet content without reading tags. Consequently, no operation + of Pod can dependent on the tags of the packets. This is a design choice to ensure that + the pods act as pure functions which is a necessary condition to guarantee reproducibility. + """ + + def __init__( + self, error_handling: Literal["raise", "ignore", "warn"] = "raise", **kwargs + ): + super().__init__(**kwargs) + self._active = True + self.error_handling = error_handling + + def is_active(self) -> bool: + """ + Check if the pod is active. If not, it will not process any packets. + """ + return self._active + + def set_active(self, active: bool) -> None: + """ + Set the active state of the pod. If set to False, the pod will not process any packets. + """ + self._active = active + + + def process_stream(self, *streams: SyncStream) -> tuple[SyncStream, ...]: + """ + Prepare the incoming streams for execution in the pod. This default implementation + joins all the input streams together. + """ + # if multiple streams are provided, join them + # otherwise, return as is + combined_streams = list(streams) + if len(streams) > 1: + stream = streams[0] + for next_stream in streams[1:]: + stream = Join()(stream, next_stream) + combined_streams = [stream] + return tuple(combined_streams) + + def pre_forward_hook( + self, *streams: SyncStream, **kwargs + ) -> tuple[SyncStream, ...]: + return self.process_stream(*streams) + + def generator_completion_hook(self, n_computed: int) -> None: + """ + Hook that is called when the generator is completed. This can be used to + perform any finalization steps, such as closing resources or logging. + """ + logger.debug(f"Generator completed with {n_computed} items processed.") + + def forward(self, *streams: SyncStream) -> SyncStream: + # at this point, streams should have been joined into one + assert len(streams) == 1, "Only one stream is supported in forward() of Pod" + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + n_computed = 0 + for tag, packet in stream: + try: + tag, output_packet = self.call(tag, packet) + if output_packet is None: + logger.debug( + f"Call returned None as output for tag {tag}. Skipping..." + ) + continue + n_computed += 1 + logger.debug(f"Computed item {n_computed}") + yield tag, output_packet + + except Exception as e: + logger.error(f"Error processing packet {packet}: {e}") + if self.error_handling == "raise": + raise e + elif self.error_handling == "warn": + warnings.warn(f"Error processing packet {packet}: {e}") + continue + elif self.error_handling == "ignore": + continue + else: + raise ValueError( + f"Unknown error handling mode: {self.error_handling} encountered while handling error:" + ) from e + self.generator_completion_hook(n_computed) + + return SyncStreamFromGenerator(generator) + + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: ... + + +def function_pod( + output_keys: str | Collection[str] | None = None, + function_name: str | None = None, + label: str | None = None, + **kwargs, +) -> Callable[..., "FunctionPod"]: + """ + Decorator that wraps a function in a FunctionPod instance. 
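A minimal usage sketch of the decorator described here, assuming a module-level function whose type annotations are sufficient for typespec extraction (the example function is illustrative):

from orcapod.core.pod import function_pod


@function_pod(output_keys="doubled")
def double(x: int) -> int:
    """Receives packet fields as keyword arguments; the return value becomes the 'doubled' field."""
    return x * 2


# `double` is now a FunctionPod (a Kernel), so it can be applied to a stream of packets:
#     output_stream = double(input_stream)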
+ + Args: + output_keys: Keys for the function output(s) + function_name: Name of the function pod; if None, defaults to the function name + **kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details. + + Returns: + FunctionPod instance wrapping the decorated function + """ + + def decorator(func) -> FunctionPod: + if func.__name__ == "": + raise ValueError("Lambda functions cannot be used with function_pod") + + if not hasattr(func, "__module__") or func.__module__ is None: + raise ValueError( + f"Function {func.__name__} must be defined at module level" + ) + + # Store the original function in the module for pickling purposes + # and make sure to change the name of the function + module = sys.modules[func.__module__] + base_function_name = func.__name__ + new_function_name = f"_original_{func.__name__}" + setattr(module, new_function_name, func) + # rename the function to be consistent and make it pickleable + setattr(func, "__name__", new_function_name) + setattr(func, "__qualname__", new_function_name) + + # Create a simple typed function pod + pod = FunctionPod( + function=func, + output_keys=output_keys, + function_name=function_name or base_function_name, + label=label, + **kwargs, + ) + return pod + + return decorator + + +class FunctionPod(Pod): + def __init__( + self, + function: Callable[..., Any], + output_keys: str | Collection[str] | None = None, + function_name=None, + input_types: TypeSpec | None = None, + output_types: TypeSpec | Sequence[type] | None = None, + label: str | None = None, + packet_type_registry=None, + function_info_extractor: FunctionInfoExtractor | None = None, + **kwargs, + ) -> None: + self.function = function + if output_keys is None: + output_keys = [] + if isinstance(output_keys, str): + output_keys = [output_keys] + self.output_keys = output_keys + if function_name is None: + if hasattr(self.function, "__name__"): + function_name = getattr(self.function, "__name__") + else: + raise ValueError( + "function_name must be provided if function has no __name__ attribute" + ) + self.function_name = function_name + super().__init__(label=label or self.function_name, **kwargs) + + if packet_type_registry is None: + # TODO: reconsider the use of default registry here + packet_type_registry = default_registry + + self.registry = packet_type_registry + self.function_info_extractor = function_info_extractor + + # extract input and output types from the function signature + self.function_input_typespec, self.function_output_typespec = ( + extract_function_typespecs( + self.function, + self.output_keys, + input_types=input_types, + output_types=output_types, + ) + ) + + self.input_converter = PacketConverter(self.function_input_typespec, self.registry) + self.output_converter = PacketConverter( + self.function_output_typespec, self.registry + ) + + def get_function_typespecs(self) -> tuple[TypeSpec, TypeSpec]: + return self.function_input_typespec, self.function_output_typespec + + + def __repr__(self) -> str: + return f"FunctionPod:{self.function!r}" + + def __str__(self) -> str: + func_sig = get_function_signature(self.function) + return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" + + def call(self, tag, packet) -> tuple[Tag, Packet | None]: + if not self.is_active(): + logger.info( + f"Pod is not active: skipping computation on input packet {packet}" + ) + return tag, None + output_values = [] + + values = self.function(**packet) + + if len(self.output_keys) == 0: + output_values = 
[] + elif len(self.output_keys) == 1: + output_values = [values] # type: ignore + elif isinstance(values, Iterable): + output_values = list(values) # type: ignore + elif len(self.output_keys) > 1: + raise ValueError( + "Values returned by function must be a pathlike or a sequence of pathlikes" + ) + + if len(output_values) != len(self.output_keys): + raise ValueError( + f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" + ) + + output_packet: Packet = {k: v for k, v in zip(self.output_keys, output_values)} + return tag, output_packet + + def identity_structure(self, *streams) -> Any: + # construct identity structure for the function + # if function_info_extractor is available, use that but substitute the function_name + if self.function_info_extractor is not None: + function_info = self.function_info_extractor.extract_function_info( + self.function, + function_name=self.function_name, + input_typespec=self.function_input_typespec, + output_typespec=self.function_output_typespec, + ) + else: + # use basic information only + function_info = { + "name": self.function_name, + "input_typespec": self.function_input_typespec, + "output_typespec": self.function_output_typespec, + } + function_info["output_keys"] = tuple(self.output_keys) + + return ( + self.__class__.__name__, + function_info, + ) + streams + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + stream = self.process_stream(*streams) + if len(stream) < 1: + tag_keys = None + else: + tag_keys, _ = stream[0].keys(trigger_run=trigger_run) + return tag_keys, tuple(self.output_keys) + + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + stream = self.process_stream(*streams) + if len(stream) < 1: + tag_typespec = None + else: + tag_typespec, _ = stream[0].types(trigger_run=trigger_run) + return tag_typespec, self.function_output_typespec + + def claims_unique_tags( + self, *streams: SyncStream, trigger_run: bool = False + ) -> bool | None: + stream = self.process_stream(*streams) + return stream[0].claims_unique_tags(trigger_run=trigger_run) diff --git a/src/orcapod/core/pod_legacy.py b/src/orcapod/core/pod_legacy.py new file mode 100644 index 0000000..32c8efb --- /dev/null +++ b/src/orcapod/core/pod_legacy.py @@ -0,0 +1,373 @@ +import logging +import warnings +import sys +from collections.abc import Callable, Collection, Iterable, Iterator +from typing import ( + Any, + Literal, +) + +from orcapod.types import Packet, PathSet, PodFunction, Tag + +from orcapod.hashing import ( + get_function_signature, + hash_function, +) +from orcapod.core.base import Kernel +from orcapod.core.operators import Join +from orcapod.core.streams import SyncStream, SyncStreamFromGenerator +from orcapod.store import DataStore, NoOpDataStore + + +logger = logging.getLogger(__name__) + + +class Pod(Kernel): + """ + An (abstract) base class for all pods. A pod can be seen as a special type of operation that + only operates on the packet content without reading tags. Consequently, no operation + of Pod can dependent on the tags of the packets. This is a design choice to ensure that + the pods act as pure functions which is a necessary condition to guarantee reproducibility. 
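A sketch of a custom pod under the call() contract defined above, shown against the new orcapod.core.pod.Pod (the legacy class below has the same contract); the subclass itself is illustrative:

from orcapod.core.pod import Pod
from orcapod.types import Packet, Tag


class AddOne(Pod):
    """Hypothetical pod: transforms the packet only, leaving the tag untouched."""

    def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]:
        # returning None as the output packet would make forward() skip this item
        return tag, {"y": packet["x"] + 1}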
+ """ + + def __init__( + self, error_handling: Literal["raise", "ignore", "warn"] = "raise", **kwargs + ): + super().__init__(**kwargs) + self.error_handling = error_handling + self._active = True + + def set_active(self, active=True): + self._active = active + + def is_active(self) -> bool: + return self._active + + def process_stream(self, *streams: SyncStream) -> tuple[SyncStream, ...]: + """ + Prepare the incoming streams for execution in the pod. This default implementation + joins all the streams together and raises and error if no streams are provided. + """ + # if multiple streams are provided, join them + # otherwise, return as is + combined_streams = list(streams) + if len(streams) > 1: + stream = streams[0] + for next_stream in streams[1:]: + stream = Join()(stream, next_stream) + combined_streams = [stream] + return tuple(combined_streams) + + def pre_forward_hook( + self, *streams: SyncStream, **kwargs + ) -> tuple[SyncStream, ...]: + return self.process_stream(*streams) + + def forward(self, *streams: SyncStream) -> SyncStream: + # if multiple streams are provided, join them + if len(streams) > 1: + raise ValueError("Multiple streams should be joined before calling forward") + if len(streams) == 0: + raise ValueError("No streams provided to forward") + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + n_computed = 0 + for tag, packet in stream: + try: + tag, output_packet = self.call(tag, packet) + if output_packet is None: + logger.info( + f"Call returned None as output for tag {tag}. Skipping..." + ) + continue + n_computed += 1 + logger.info(f"Computed item {n_computed}") + yield tag, output_packet + + except Exception as e: + logger.error(f"Error processing packet {packet}: {e}") + if self.error_handling == "raise": + raise e + elif self.error_handling == "warn": + warnings.warn(f"Error processing packet {packet}: {e}") + continue + elif self.error_handling == "ignore": + continue + else: + raise ValueError( + f"Unknown error handling mode: {self.error_handling} encountered while handling error:" + ) from e + + return SyncStreamFromGenerator(generator) + + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: ... + + +def function_pod( + output_keys: Collection[str] | None = None, + function_name: str | None = None, + data_store: DataStore | None = None, + store_name: str | None = None, + function_hash_mode: Literal["signature", "content", "name", "custom"] = "name", + custom_hash: int | None = None, + force_computation: bool = False, + skip_memoization: bool = False, + error_handling: Literal["raise", "ignore", "warn"] = "raise", + **kwargs, +) -> Callable[..., "FunctionPod"]: + """ + Decorator that wraps a function in a FunctionPod instance. 
+ + Args: + output_keys: Keys for the function output + force_computation: Whether to force computation + skip_memoization: Whether to skip memoization + + Returns: + FunctionPod instance wrapping the decorated function + """ + + def decorator(func) -> FunctionPod: + if func.__name__ == "": + raise ValueError("Lambda functions cannot be used with function_pod") + + if not hasattr(func, "__module__") or func.__module__ is None: + raise ValueError( + f"Function {func.__name__} must be defined at module level" + ) + + # Store the original function in the module for pickling purposes + # and make sure to change the name of the function + module = sys.modules[func.__module__] + base_function_name = func.__name__ + new_function_name = f"_original_{func.__name__}" + setattr(module, new_function_name, func) + # rename the function to be consistent and make it pickleable + setattr(func, "__name__", new_function_name) + setattr(func, "__qualname__", new_function_name) + + # Create the FunctionPod + pod = FunctionPod( + function=func, + output_keys=output_keys, + function_name=function_name or base_function_name, + data_store=data_store, + store_name=store_name, + function_hash_mode=function_hash_mode, + custom_hash=custom_hash, + force_computation=force_computation, + skip_memoization=skip_memoization, + error_handling=error_handling, + **kwargs, + ) + + return pod + + return decorator + + +class FunctionPod(Pod): + """ + A pod that wraps a function and allows it to be used as an operation in a stream. + This pod can be used to apply a function to the packets in a stream, with optional memoization + and caching of results. It can also handle multiple output keys and error handling. + The function should accept keyword arguments that correspond to the keys in the packets. 
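An example of a function fitting this contract (it receives packet values as keyword arguments and returns an output path, per the docstring); paths and behaviour are illustrative only:

from pathlib import Path


def uppercase_copy(text_file) -> Path:
    """Called as function(**packet): 'text_file' is a packet key; the returned path becomes the output value."""
    out = Path(text_file).with_suffix(".upper.txt")
    out.write_text(Path(text_file).read_text().upper())
    return out

# wrapped as, e.g., FunctionPod(uppercase_copy, output_keys=["upper_file"], data_store=my_store)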
+ The output of the function should be a path or a collection of paths that correspond to the output keys.""" + + def __init__( + self, + function: PodFunction, + output_keys: Collection[str] | None = None, + function_name=None, + data_store: DataStore | None = None, + store_name: str | None = None, + function_hash_mode: Literal["signature", "content", "name", "custom"] = "name", + custom_hash: int | None = None, + label: str | None = None, + force_computation: bool = False, + skip_memoization_lookup: bool = False, + skip_memoization: bool = False, + error_handling: Literal["raise", "ignore", "warn"] = "raise", + _hash_function_kwargs: dict | None = None, + **kwargs, + ) -> None: + super().__init__(label=label, **kwargs) + self.function = function + self.output_keys = output_keys or [] + if function_name is None: + if hasattr(self.function, "__name__"): + function_name = getattr(self.function, "__name__") + else: + raise ValueError( + "function_name must be provided if function has no __name__ attribute" + ) + + self.function_name = function_name + self.data_store = data_store if data_store is not None else NoOpDataStore() + self.store_name = store_name or function_name + self.function_hash_mode = function_hash_mode + self.custom_hash = custom_hash + self.force_computation = force_computation + self.skip_memoization_lookup = skip_memoization_lookup + self.skip_memoization = skip_memoization + self.error_handling = error_handling + self._hash_function_kwargs = _hash_function_kwargs + + def __repr__(self) -> str: + func_sig = get_function_signature(self.function) + return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + stream = self.process_stream(*streams) + tag_keys, _ = stream[0].keys(trigger_run=trigger_run) + return tag_keys, tuple(self.output_keys) + + def is_memoized(self, packet: Packet) -> bool: + return self.retrieve_memoized(packet) is not None + + def retrieve_memoized(self, packet: Packet) -> Packet | None: + """ + Retrieve a memoized packet from the data store. + Returns None if no memoized packet is found. + """ + return self.data_store.retrieve_memoized( + self.store_name, + self.content_hash(char_count=16), + packet, + ) + + def memoize( + self, + packet: Packet, + output_packet: Packet, + ) -> Packet: + """ + Memoize the output packet in the data store. + Returns the memoized packet. 
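A hedged sketch of the lookup-then-store flow these helpers support (the real loop lives in forward() below; `compute` stands in for the wrapped function call):

def cached_apply(pod, tag, packet, compute):
    """Return a memoized result if present; otherwise compute, store, and return it."""
    cached = pod.retrieve_memoized(packet)
    if cached is not None:
        return tag, cached
    output_packet = compute(packet)
    # memoize() may rewrite values (e.g. file paths) as it copies results into the data store
    return tag, pod.memoize(packet, output_packet)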
+ """ + return self.data_store.memoize( + self.store_name, + self.content_hash(char_count=16), # identity of this function pod + packet, + output_packet, + ) + + def forward(self, *streams: SyncStream) -> SyncStream: + # if multiple streams are provided, join them + if len(streams) > 1: + raise ValueError("Multiple streams should be joined before calling forward") + if len(streams) == 0: + raise ValueError("No streams provided to forward") + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + n_computed = 0 + for tag, packet in stream: + output_values: list["PathSet"] = [] + try: + if not self.skip_memoization_lookup: + memoized_packet = self.retrieve_memoized(packet) + else: + memoized_packet = None + if not self.force_computation and memoized_packet is not None: + logger.info("Memoized packet found, skipping computation") + yield tag, memoized_packet + continue + if not self.is_active(): + logger.info( + "Pod is not active: skipping computation of a new entry" + ) + continue + values = self.function(**packet) + + if len(self.output_keys) == 0: + output_values = [] + elif len(self.output_keys) == 1: + output_values = [values] # type: ignore + elif isinstance(values, Iterable): + output_values = list(values) # type: ignore + elif len(self.output_keys) > 1: + raise ValueError( + "Values returned by function must be a pathlike or a sequence of pathlikes" + ) + + if len(output_values) != len(self.output_keys): + raise ValueError( + f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" + ) + except Exception as e: + logger.error(f"Error processing packet {packet}: {e}") + if self.error_handling == "raise": + raise e + elif self.error_handling == "ignore": + continue + elif self.error_handling == "warn": + warnings.warn(f"Error processing packet {packet}: {e}") + continue + else: + raise ValueError( + f"Unknown error handling mode: {self.error_handling} encountered while handling error:" + ) from e + + output_packet: Packet = { + k: v for k, v in zip(self.output_keys, output_values) + } + + if not self.skip_memoization: + # output packet may be modified by the memoization process + # e.g. 
if the output is a file, the path may be changed + output_packet = self.memoize(packet, output_packet) # type: ignore + + n_computed += 1 + logger.info(f"Computed item {n_computed}") + yield tag, output_packet + + return SyncStreamFromGenerator(generator) + + def identity_structure(self, *streams) -> Any: + content_kwargs = self._hash_function_kwargs + if self.function_hash_mode == "content": + if content_kwargs is None: + content_kwargs = { + "include_name": False, + "include_module": False, + "include_declaration": False, + } + function_hash_value = hash_function( + self.function, + name_override=self.function_name, + function_hash_mode="content", + content_kwargs=content_kwargs, + ) + elif self.function_hash_mode == "signature": + function_hash_value = hash_function( + self.function, + name_override=self.function_name, + function_hash_mode="signature", + content_kwargs=content_kwargs, + ) + elif self.function_hash_mode == "name": + function_hash_value = hash_function( + self.function, + name_override=self.function_name, + function_hash_mode="name", + content_kwargs=content_kwargs, + ) + elif self.function_hash_mode == "custom": + if self.custom_hash is None: + raise ValueError("Custom hash function not provided") + function_hash_value = self.custom_hash + else: + raise ValueError( + f"Unknown function hash mode: {self.function_hash_mode}. " + "Must be one of 'content', 'signature', 'name', or 'custom'." + ) + + return ( + self.__class__.__name__, + function_hash_value, + tuple(self.output_keys), + ) + tuple(streams) diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py index efc2c42..2532582 100644 --- a/src/orcapod/core/tracker.py +++ b/src/orcapod/core/tracker.py @@ -1,5 +1,39 @@ -from orcapod.core.base import Invocation, Kernel, Tracker +from orcapod.core.base import Invocation, Kernel, Tracker, SyncStream, TypeSpec +from collections.abc import Collection +from typing import Any +class StubKernel(Kernel): + def __init__(self, stream: SyncStream, **kwargs): + super().__init__(skip_tracking=True, **kwargs) + self.stream = stream + + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 0: + raise ValueError( + "StubKernel does not support forwarding streams. " + "It generates its own stream from the file system." + ) + return self.stream + + def identity_structure(self, *streams) -> Any: + if len(streams) != 0: + raise ValueError( + "StubKernel does not support forwarding streams. " + "It generates its own stream from the file system." 
+ ) + + return (self.__class__.__name__, self.stream) + + def types(self, *streams: SyncStream, **kwargs) -> tuple[TypeSpec|None, TypeSpec|None]: + return self.stream.types() + + def keys(self, *streams: SyncStream, **kwargs) -> tuple[Collection[str]|None, Collection[str]|None]: + return self.stream.keys() + + + + + class GraphTracker(Tracker): """ @@ -44,6 +78,7 @@ def generate_namemap(self) -> dict[Invocation, str]: def generate_graph(self): import networkx as nx + G = nx.DiGraph() # Add edges for each invocation @@ -51,15 +86,20 @@ def generate_graph(self): for invocation in invocations: for upstream in invocation.streams: # if upstream.invocation is not in the graph, add it - if upstream.invocation not in G: - G.add_node(upstream.invocation) - G.add_edge(upstream.invocation, invocation, stream=upstream) + upstream_invocation = upstream.invocation + if upstream_invocation is None: + # If upstream is None, create a stub kernel + upstream_invocation = Invocation(StubKernel(upstream, label="StubInput"), []) + if upstream_invocation not in G: + G.add_node(upstream_invocation) + G.add_edge(upstream_invocation, invocation, stream=upstream) return G def draw_graph(self): import networkx as nx import matplotlib.pyplot as plt + G = self.generate_graph() labels = self.generate_namemap() diff --git a/src/orcapod/dj/pod.py b/src/orcapod/dj/pod.py index 815b2dc..7101090 100644 --- a/src/orcapod/dj/pod.py +++ b/src/orcapod/dj/pod.py @@ -5,7 +5,7 @@ from datajoint import Schema from datajoint.table import Table -from ..pod import FunctionPod, Pod +from orcapod.core.pod import FunctionPod, Pod from ..utils.name import pascal_to_snake, snake_to_pascal from .mapper import JoinQuery from .operation import QueryOperation diff --git a/src/orcapod/dj/tracker.py b/src/orcapod/dj/tracker.py index 24df900..3276ba9 100644 --- a/src/orcapod/dj/tracker.py +++ b/src/orcapod/dj/tracker.py @@ -8,7 +8,7 @@ from orcapod.core.base import Kernel, Source from orcapod.core.operators import Operator, Merge -from orcapod.pod import FunctionPod +from orcapod.core.pod import FunctionPod from orcapod.core.tracker import GraphTracker from .mapper import convert_to_query_mapper diff --git a/src/orcapod/hashing/core.py b/src/orcapod/hashing/core.py index c711f63..66b4e4d 100644 --- a/src/orcapod/hashing/core.py +++ b/src/orcapod/hashing/core.py @@ -5,7 +5,7 @@ A library for creating stable, content-based hashes that remain consistent across Python sessions, suitable for arbitrarily nested data structures and custom objects via HashableMixin. """ - +WARN_NONE_IDENTITY=False import hashlib import inspect import json @@ -175,11 +175,12 @@ def content_hash(self, char_count: Optional[int] = 16) -> str: # If no custom structure is provided, use the class name # We avoid using id() since it's not stable across sessions if structure is None: - logger.warning( - f"HashableMixin.content_hash called on {self.__class__.__name__} " - "instance that returned identity_structure() of None. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) + if WARN_NONE_IDENTITY: + logger.warning( + f"HashableMixin.content_hash called on {self.__class__.__name__} " + "instance that returned identity_structure() of None. " + "Using class name as default identity, which may not correctly reflect object uniqueness." 
+ ) # Fall back to class name for consistent behavior return f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" @@ -205,11 +206,12 @@ def content_hash_int(self, hexdigits: int = 16) -> int: # If no custom structure is provided, use the class name # We avoid using id() since it's not stable across sessions if structure is None: - logger.warning( - f"HashableMixin.content_hash_int called on {self.__class__.__name__} " - "instance without identity_structure() implementation. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) + if WARN_NONE_IDENTITY: + logger.warning( + f"HashableMixin.content_hash_int called on {self.__class__.__name__} " + "instance that returned identity_structure() of None. " + "Using class name as default identity, which may not correctly reflect object uniqueness." + ) # Use the same default identity as content_hash for consistency default_identity = ( f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" @@ -235,11 +237,12 @@ def content_hash_uuid(self) -> UUID: # If no custom structure is provided, use the class name # We avoid using id() since it's not stable across sessions if structure is None: - logger.warning( - f"HashableMixin.content_hash_uuid called on {self.__class__.__name__} " - "instance without identity_structure() implementation. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) + if WARN_NONE_IDENTITY: + logger.warning( + f"HashableMixin.content_hash_uuid called on {self.__class__.__name__} " + "instance without identity_structure() implementation. " + "Using class name as default identity, which may not correctly reflect object uniqueness." + ) # Use the same default identity as content_hash for consistency default_identity = ( f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" diff --git a/src/orcapod/hashing/hashing_legacy.py b/src/orcapod/hashing/hashing_legacy.py deleted file mode 100644 index 353a4f9..0000000 --- a/src/orcapod/hashing/hashing_legacy.py +++ /dev/null @@ -1,269 +0,0 @@ -# # a function to hash a dictionary of key value pairs into uuid -# from collections.abc import Collection, Mapping -# import hashlib -# import uuid -# from uuid import UUID -# from typing import Any, Dict, Optional, Union -# import inspect -# import json - -# import hashlib - -# # arbitrary depth of nested dictionaries -# T = Dict[str, Union[str, "T"]] - - -# # TODO: implement proper recursive hashing - - -# def hash_dict(d: T) -> UUID: -# # Convert the dictionary to a string representation -# dict_str = str(sorted(d.items())) - -# # Create a hash of the string representation -# hash_object = hashlib.sha256(dict_str.encode("utf-8")) - -# # Convert the hash to a UUID -# hash_uuid = uuid.UUID(hash_object.hexdigest()) - -# return hash_uuid - - -# def stable_hash(s): -# """Create a stable hash that returns the same integer value across sessions.""" -# # Convert input to bytes if it's not already -# if not isinstance(s, bytes): -# s = str(s).encode("utf-8") - -# hash_hex = hashlib.sha256(s).hexdigest() -# return int(hash_hex[:16], 16) - - -# def hash_function(function, function_hash_mode: str = "content", hasher_kwargs=None) -> str: -# """ -# Hash a function based on its content, signature, or name. 
-# -# Args: -# function: The function to hash -# function_hash_mode: The mode of hashing ('content', 'signature', 'name') -# function_name: Optional name for the function (if not provided, uses function's __name__) - -# Returns: -# A string representing the hash of the function -# """ -# if hasher_kwargs is None: -# hasher_kwargs = {} - -# if function_hash_mode == "content": -# function_hash = function_content_hash(function, **hasher_kwargs) -# elif function_hash_mode == "signature": -# function_hash = stable_hash(get_function_signature(function, **hasher_kwargs)) -# elif function_hash_mode == "name": -# function_hash = stable_hash(function.__name__) - -# return function_hash - - -# def function_content_hash( -# func, exclude_name=False, exclude_module=False, exclude_declaration=False, return_components=False -# ): -# """ -# Compute a hash based on the function's source code, name, module, and closure variables. -# """ -# components = [] - -# # Add function name -# if not exclude_name: -# components.append(f"name:{func.__name__}") - -# # Add module -# if not exclude_module: -# components.append(f"module:{func.__module__}") - -# # Get the function's source code -# try: -# source = inspect.getsource(func) -# # Clean up the source code -# source = source.strip() -# # Remove the function definition line -# if exclude_declaration: -# # find the line that starts with def and remove it -# # TODO: consider dealing with more sophisticated cases like decorators -# source = "\n".join(line for line in source.split("\n") if not line.startswith("def ")) -# components.append(f"source:{source}") -# except (IOError, TypeError): -# # If we can't get the source (e.g., built-in function), use the function's string representation -# components.append(f"repr:{repr(func)}") - -# # Add closure variables if any -# if func.__closure__: -# closure_values = [] -# for cell in func.__closure__: -# # Try to get a stable representation of the cell content -# try: -# # For simple immutable objects -# if isinstance(cell.cell_contents, (int, float, str, bool, type(None))): -# closure_values.append(repr(cell.cell_contents)) -# # For other objects, we'll use their string representation -# else: -# closure_values.append(str(cell.cell_contents)) -# except: -# # If we can't get a stable representation, use the cell's id -# closure_values.append(f"cell_id:{id(cell)}") - -# components.append(f"closure:{','.join(closure_values)}") - -# # Add function attributes that affect behavior -# if hasattr(func, "__defaults__") and func.__defaults__: -# defaults_str = ",".join(repr(d) for d in func.__defaults__) -# components.append(f"defaults:{defaults_str}") - -# if hasattr(func, "__kwdefaults__") and func.__kwdefaults__: -# kwdefaults_str = ",".join(f"{k}={repr(v)}" for k, v in func.__kwdefaults__.items()) -# components.append(f"kwdefaults:{kwdefaults_str}") - -# # Function's code object properties (excluding filename and line numbers) -# code = func.__code__ -# code_props = { -# "co_argcount": code.co_argcount, -# "co_posonlyargcount": getattr(code, "co_posonlyargcount", 0), # Python 3.8+ -# "co_kwonlyargcount": code.co_kwonlyargcount, -# "co_nlocals": code.co_nlocals, -# "co_stacksize": code.co_stacksize, -# "co_flags": code.co_flags, -# "co_code": code.co_code, -# "co_names": code.co_names, -# "co_varnames": code.co_varnames, -# } -# components.append(f"code_properties:{repr(code_props)}") -# if return_components: -# return components - -# # Join all components and compute hash -# combined = "\n".join(components) -# return 
hashlib.sha256(combined.encode("utf-8")).hexdigest() - - -# class HashableMixin: -# """A mixin that provides content-based hashing functionality.""" - -# def identity_structure(self) -> Any: -# """ -# Return a structure that represents the identity of this object. -# By default, returns None to indicate that no custom structure is provided. -# Subclasses should override this method to provide meaningful representations. - -# Returns: -# None to indicate no custom structure (use default hash) -# """ -# return None - -# def content_hash(self, char_count: Optional[int] = 16) -> str: -# """ -# Generate a stable string hash based on the object's content. - -# Returns: -# str: A hexadecimal digest representing the object's content -# """ -# # Get the identity structure -# structure = self.identity_structure() - -# # TODO: consider returning __hash__ based value if structure is None - -# # Generate a hash from the identity structure -# return self._hash_structure(structure, char_count=char_count) - -# def content_hash_int(self, hexdigits=16) -> int: -# """ -# Generate a stable integer hash based on the object's content. - -# Returns: -# int: An integer representing the object's content -# """ -# return int(self.content_hash(char_count=None)[:hexdigits], 16) - -# def __hash__(self) -> int: -# """ -# Hash implementation that uses the identity structure if provided, -# otherwise falls back to the superclass's hash method. - -# Returns: -# int: A hash value based on either content or identity -# """ -# # Get the identity structure -# structure = self.identity_structure() - -# # If no custom structure is provided, use the superclass's hash -# if structure is None: -# return super().__hash__() - -# # Generate a hash and convert to integer -# hash_hex = self._hash_structure(structure, char_count=None) -# return int(hash_hex[:16], 16) - -# def _hash_structure(self, structure: Any, char_count: Optional[int] = 16) -> str: -# """ -# Helper method to compute a hash string from a structure. - -# Args: -# structure: The structure to hash - -# Returns: -# str: A hexadecimal hash digest of the structure -# """ -# processed = self._process_structure(structure) -# json_str = json.dumps(processed, sort_keys=True).encode() -# return hashlib.sha256(json_str).hexdigest()[:char_count] - -# def _process_structure(self, obj: Any) -> Any: -# """ -# Recursively process a structure to prepare it for hashing. 
- -# Args: -# obj: The object or structure to process - -# Returns: -# A processed version of the structure with HashableMixin objects replaced by their hashes -# """ -# # Handle None -# if obj is None: -# return "None" - -# # If the object is a HashableMixin, use its content_hash -# if isinstance(obj, HashableMixin): -# # Don't call content_hash on self to avoid cycles -# if obj is self: -# # TODO: carefully consider this case -# # Use the superclass's hash for self -# return str(super(HashableMixin, self).__hash__()) -# return obj.content_hash() - -# # Handle basic types -# if isinstance(obj, (str, int, float, bool)): -# return str(obj) - -# # Handle named tuples (which are subclasses of tuple) -# if hasattr(obj, "_fields") and isinstance(obj, tuple): -# # For namedtuples, convert to dict and then process -# return self._process_structure({field: value for field, value in zip(obj._fields, obj)}) - -# # Handle mappings (dict-like objects) -# if isinstance(obj, Mapping): -# return {str(k): self._process_structure(v) for k, v in sorted(obj.items(), key=lambda x: str(x[0]))} - -# # Handle sets and frozensets specifically -# if isinstance(obj, (set, frozenset)): -# # Process each item first, then sort the processed results -# processed_items = [self._process_structure(item) for item in obj] -# return sorted(processed_items, key=str) - -# # Handle collections (list-like objects) -# if isinstance(obj, Collection): -# return [self._process_structure(item) for item in obj] - -# # For bytes and bytearray, convert to hex representation -# if isinstance(obj, (bytes, bytearray)): -# return obj.hex() - -# # For other objects, just use their string representation -# return str(obj) diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index 5e8b07c..36155bb 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -137,6 +137,6 @@ def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_types: TypeSpec | None = None, - output_types: TypeSpec | None = None, + input_typespec: TypeSpec | None = None, + output_typespec: TypeSpec | None = None, ) -> dict[str, Any]: ... diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index f160f2b..5df050f 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -1,3 +1,5 @@ +from collections import defaultdict +from collections.abc import Collection, Iterator import json import logging import pickle @@ -7,12 +9,21 @@ from pathlib import Path from typing import Any, Protocol, runtime_checkable -import networkx as nx import pandas as pd -from orcapod.core.base import Invocation, Kernel +from orcapod.core import Invocation, Kernel, SyncStream +from orcapod.core.pod import FunctionPod +from orcapod.pipeline.wrappers import KernelNode, FunctionPodNode, Node + from orcapod.hashing import hash_to_hex from orcapod.core.tracker import GraphTracker +from orcapod.hashing import ObjectHasher, ArrowHasher +from orcapod.types import TypeSpec, Tag, Packet +from orcapod.core.streams import SyncStreamFromGenerator +from orcapod.store import ArrowDataStore +from orcapod.types.registry import PacketConverter, TypeRegistry +from orcapod.types import default_registry +from orcapod.utils.stream_utils import merge_typespecs, get_typespec logger = logging.getLogger(__name__) @@ -29,12 +40,12 @@ class Pipeline(GraphTracker): Replaces the old Tracker with better persistence and view capabilities. 
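A rough usage sketch against the reworked constructor and compile() shown below; the store instances and the recorded invocations are assumed to exist, and node access by label follows the __getattr__ added in this patch:

from orcapod.pipeline.pipeline import Pipeline


def compile_pipeline(results_store, pipeline_store):
    # results_store caches FunctionPod outputs; pipeline_store holds tag/bookkeeping tables
    pipeline = Pipeline("example", results_store, pipeline_store)
    # ... build sources/pods and invoke them while this tracker is recording ...
    node_lut, edge_lut, proposed_labels, labels_to_nodes = pipeline.compile()
    # compiled nodes are then reachable by label, e.g. pipeline.my_pod for a node labeled "my_pod"
    return pipeline, labels_to_nodes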
""" - def __init__(self, name: str | None = None): + def __init__(self, name: str, results_store: ArrowDataStore, pipeline_store: ArrowDataStore) -> None: super().__init__() self.name = name or f"pipeline_{id(self)}" - self._view_registry: dict[str, "PipelineView"] = {} - self._cache_dir = Path(".pipeline_cache") / self.name - self._cache_dir.mkdir(parents=True, exist_ok=True) + self.results_store = results_store + self.pipeline_store = pipeline_store + self.labels_to_nodes = {} # Core Pipeline Operations def save(self, path: Path | str) -> None: @@ -66,6 +77,62 @@ def save(self, path: Path | str) -> None: temp_path.unlink() raise + def wrap_invocation( + self, kernel: Kernel, input_nodes: Collection[Node] + ) -> Node: + if isinstance(kernel, FunctionPod): + return FunctionPodNode(kernel, input_nodes, output_store=self.results_store, tag_store=self.pipeline_store) + return KernelNode(kernel, input_nodes, output_store=self.pipeline_store) + + def compile(self): + import networkx as nx + G = self.generate_graph() + + # Proposed labels for each Kernel in the graph + # If name collides, unique name is generated by appending an index + proposed_labels = defaultdict(list) + node_lut = {} + edge_lut : dict[SyncStream, Node]= {} + for invocation in nx.topological_sort(G): + # map streams to the new streams based on Nodes + input_nodes = [edge_lut[stream] for stream in invocation.streams] + new_node = self.wrap_invocation(invocation.kernel, input_nodes) + + # register the new node against the original invocation + node_lut[invocation] = new_node + # register the new node in the proposed labels -- if duplicates occur, will resolve later + proposed_labels[new_node.label].append(new_node) + + for edge in G.out_edges(invocation): + edge_lut[G.edges[edge]["stream"]] = new_node + + # resolve duplicates in proposed_labels + labels_to_nodes = {} + for label, nodes in proposed_labels.items(): + if len(nodes) > 1: + # If multiple nodes have the same label, append index to make it unique + for idx, node in enumerate(nodes): + node.label = f"{label}_{idx}" + labels_to_nodes[node.label] = node + else: + # If only one node, keep the original label + nodes[0].label = label + labels_to_nodes[label] = nodes[0] + + self.labels_to_nodes = labels_to_nodes + return node_lut, edge_lut, proposed_labels, labels_to_nodes + + def __getattr__(self, item: str) -> Any: + """Allow direct access to pipeline attributes""" + if item in self.labels_to_nodes: + return self.labels_to_nodes[item] + raise AttributeError(f"Pipeline has no attribute '{item}'") + + def __dir__(self): + # Include both regular attributes and dynamic ones + return list(super().__dir__()) + list(self.labels_to_nodes.keys()) + + @classmethod def load(cls, path: Path | str) -> "Pipeline": """Load complete pipeline state""" @@ -74,7 +141,7 @@ def load(cls, path: Path | str) -> "Pipeline": with open(path, "rb") as f: state = pickle.load(f) - pipeline = cls(state["name"]) + pipeline = cls(state["name"], state["output_store"]) pipeline.invocation_lut = state["invocation_lut"] logger.info(f"Pipeline '{pipeline.name}' loaded from {path}") @@ -103,628 +170,5 @@ def _validate_serializable(self) -> None: + "\n".join(f" - {issue}" for issue in issues) + "\n\nOnly named functions are supported for serialization." 
) + - # View Management - def as_view( - self, renderer: "ViewRenderer", view_id: str | None = None, **kwargs - ) -> "PipelineView": - """Get a view of this pipeline using the specified renderer""" - view_id = ( - view_id - or f"{renderer.__class__.__name__.lower()}_{len(self._view_registry)}" - ) - - if view_id not in self._view_registry: - self._view_registry[view_id] = renderer.create_view( - self, view_id=view_id, **kwargs - ) - return self._view_registry[view_id] - - def as_dataframe(self, view_id: str = "default", **kwargs) -> "PandasPipelineView": - """Convenience method for pandas DataFrame view""" - return self.as_view(PandasViewRenderer(), view_id=view_id, **kwargs) - - def as_graph(self) -> nx.DiGraph: - """Get the computation graph""" - return self.generate_graph() - - # Combined save/load with views - def save_with_views(self, base_path: Path | str) -> dict[str, Path]: - """Save pipeline and all its views together""" - base_path = Path(base_path) - base_path.mkdir(parents=True, exist_ok=True) - - saved_files = {} - - # Save pipeline itself - pipeline_path = base_path / "pipeline.pkl" - self.save(pipeline_path) - saved_files["pipeline"] = pipeline_path - - # Save all views - for view_id, view in self._view_registry.items(): - view_path = base_path / f"view_{view_id}.pkl" - view.save(view_path, include_pipeline=False) - saved_files[f"view_{view_id}"] = view_path - - # Save manifest - manifest = { - "pipeline_file": "pipeline.pkl", - "views": { - view_id: f"view_{view_id}.pkl" for view_id in self._view_registry.keys() - }, - "created_at": time.time(), - "pipeline_name": self.name, - } - - manifest_path = base_path / "manifest.json" - with open(manifest_path, "w") as f: - json.dump(manifest, f, indent=2) - saved_files["manifest"] = manifest_path - - return saved_files - - @classmethod - def load_with_views( - cls, base_path: Path | str - ) -> tuple["Pipeline", dict[str, "PipelineView"]]: - """Load pipeline and all its views""" - base_path = Path(base_path) - - # Load manifest - manifest_path = base_path / "manifest.json" - with open(manifest_path, "r") as f: - manifest = json.load(f) - - # Load pipeline - pipeline_path = base_path / manifest["pipeline_file"] - pipeline = cls.load(pipeline_path) - - # Load views with appropriate renderers - renderers = { - "PandasPipelineView": PandasViewRenderer(), - "DataJointPipelineView": DataJointViewRenderer(None), # Would need schema - } - - views = {} - for view_id, view_file in manifest["views"].items(): - view_path = base_path / view_file - - # Load view data to determine type - with open(view_path, "rb") as f: - view_data = pickle.load(f) - - # Find appropriate renderer - view_type = view_data.get("view_type", "PandasPipelineView") - if view_type in renderers and renderers[view_type].can_load_view(view_data): - # Load with appropriate view class - if view_type == "PandasPipelineView": - view = PandasPipelineView.load(view_path, pipeline) - else: - view = DataJointPipelineView.load(view_path, pipeline) - else: - # Default to pandas view - view = PandasPipelineView.load(view_path, pipeline) - - views[view_id] = view - pipeline._view_registry[view_id] = view - - return pipeline, views - - def get_stats(self) -> dict[str, Any]: - """Get pipeline statistics""" - total_operations = len(self.invocation_lut) - total_invocations = sum(len(invs) for invs in self.invocation_lut.values()) - - operation_types = {} - for operation in self.invocation_lut.keys(): - op_type = operation.__class__.__name__ - operation_types[op_type] = 
operation_types.get(op_type, 0) + 1 - - return { - "name": self.name, - "total_operations": total_operations, - "total_invocations": total_invocations, - "operation_types": operation_types, - "views": list(self._view_registry.keys()), - } - - -# View Renderer Protocol -@runtime_checkable -class ViewRenderer(Protocol): - """Protocol for all view renderers - uses structural typing""" - - def create_view( - self, pipeline: "Pipeline", view_id: str, **kwargs - ) -> "PipelineView": - """Create a view for the given pipeline""" - ... - - def can_load_view(self, view_data: dict[str, Any]) -> bool: - """Check if this renderer can load the given view data""" - ... - - -class PandasViewRenderer: - """Renderer for pandas DataFrame views""" - - def create_view( - self, pipeline: "Pipeline", view_id: str, **kwargs - ) -> "PandasPipelineView": - return PandasPipelineView(pipeline, view_id=view_id, **kwargs) - - def can_load_view(self, view_data: dict[str, Any]) -> bool: - return view_data.get("view_type") == "PandasPipelineView" - - -class DataJointViewRenderer: - """Renderer for DataJoint views""" - - def __init__(self, schema): - self.schema = schema - - def create_view( - self, pipeline: "Pipeline", view_id: str, **kwargs - ) -> "DataJointPipelineView": - return DataJointPipelineView(pipeline, self.schema, view_id=view_id, **kwargs) - - def can_load_view(self, view_data: dict[str, Any]) -> bool: - return view_data.get("view_type") == "DataJointPipelineView" - - -# Base class for all views -class PipelineView(ABC): - """Base class for all pipeline views""" - - def __init__(self, pipeline: Pipeline, view_id: str): - self.pipeline = pipeline - self.view_id = view_id - self._cache_dir = pipeline._cache_dir / "views" - self._cache_dir.mkdir(parents=True, exist_ok=True) - - @abstractmethod - def save(self, path: Path | str, include_pipeline: bool = True) -> None: - """Save the view""" - pass - - @classmethod - @abstractmethod - def load(cls, path: Path | str, pipeline: Pipeline | None = None) -> "PipelineView": - """Load the view""" - pass - - def _compute_pipeline_hash(self) -> str: - """Compute hash of current pipeline state for validation""" - pipeline_state = [] - for operation, invocations in self.pipeline.invocation_lut.items(): - for invocation in invocations: - pipeline_state.append(invocation.content_hash()) - return hash_to_hex(sorted(pipeline_state)) - - -# Pandas DataFrame-like view -class PandasPipelineView(PipelineView): - """ - Provides a pandas DataFrame-like interface to pipeline metadata. - Focuses on tag information for querying and filtering. 
- """ - - def __init__( - self, - pipeline: Pipeline, - view_id: str = "pandas_view", - max_records: int = 10000, - sample_size: int = 100, - ): - super().__init__(pipeline, view_id) - self.max_records = max_records - self.sample_size = sample_size - self._cached_data: pd.DataFrame | None = None - self._build_options = {"max_records": max_records, "sample_size": sample_size} - self._hash_to_data_map: dict[str, Any] = {} - - @property - def df(self) -> pd.DataFrame: - """Access the underlying DataFrame, building if necessary""" - if self._cached_data is None: - # Try to load from cache first - cache_path = self._cache_dir / f"{self.view_id}.pkl" - if cache_path.exists(): - try: - loaded_view = self.load(cache_path, self.pipeline) - if self._is_cache_valid(loaded_view): - self._cached_data = loaded_view._cached_data - self._hash_to_data_map = loaded_view._hash_to_data_map - logger.info(f"Loaded view '{self.view_id}' from cache") - return self._cached_data - except Exception as e: - logger.warning(f"Failed to load cached view: {e}") - - # Build from scratch - logger.info(f"Building view '{self.view_id}' from pipeline") - self._cached_data = self._build_metadata() - - # Auto-save after building - try: - self.save(cache_path, include_pipeline=False) - except Exception as e: - logger.warning(f"Failed to cache view: {e}") - - return self._cached_data - - def _build_metadata(self) -> pd.DataFrame: - """Build the metadata DataFrame from pipeline operations""" - metadata_records = [] - total_records = 0 - - for operation, invocations in self.pipeline.invocation_lut.items(): - if total_records >= self.max_records: - logger.warning(f"Hit max_records limit ({self.max_records})") - break - - for invocation in invocations: - try: - # Get sample of outputs, not all - records = self._extract_metadata_from_invocation( - invocation, operation - ) - for record in records: - metadata_records.append(record) - total_records += 1 - if total_records >= self.max_records: - break - - if total_records >= self.max_records: - break - - except Exception as e: - logger.warning(f"Skipping {operation.__class__.__name__}: {e}") - # Create placeholder record - placeholder = self._create_placeholder_record(invocation, operation) - metadata_records.append(placeholder) - total_records += 1 - - if not metadata_records: - # Return empty DataFrame with basic structure - return pd.DataFrame( - columns=[ - "operation_name", - "operation_hash", - "invocation_id", - "created_at", - "packet_keys", - ] - ) - - return pd.DataFrame(metadata_records) - - def _extract_metadata_from_invocation( - self, invocation: Invocation, operation: Kernel - ) -> list[dict[str, Any]]: - """Extract metadata records from a single invocation""" - records = [] - - # Try to get sample outputs from the invocation - try: - # This is tricky - we need to reconstruct the output stream - # For now, we'll create a basic record from what we know - base_record = { - "operation_name": operation.label or operation.__class__.__name__, - "operation_hash": invocation.content_hash(), - "invocation_id": hash(invocation), - "created_at": time.time(), - "operation_type": operation.__class__.__name__, - } - - # Try to get tag and packet info from the operation - try: - tag_keys, packet_keys = invocation.keys() - base_record.update( - { - "tag_keys": list(tag_keys) if tag_keys else [], - "packet_keys": list(packet_keys) if packet_keys else [], - } - ) - except Exception: - base_record.update( - { - "tag_keys": [], - "packet_keys": [], - } - ) - - records.append(base_record) - - 
except Exception as e: - logger.debug(f"Could not extract detailed metadata from {operation}: {e}") - records.append(self._create_placeholder_record(invocation, operation)) - - return records - - def _create_placeholder_record( - self, invocation: Invocation, operation: Kernel - ) -> dict[str, Any]: - """Create a placeholder record when extraction fails""" - return { - "operation_name": operation.label or operation.__class__.__name__, - "operation_hash": invocation.content_hash(), - "invocation_id": hash(invocation), - "created_at": time.time(), - "operation_type": operation.__class__.__name__, - "tag_keys": [], - "packet_keys": [], - "is_placeholder": True, - } - - # DataFrame-like interface - def __getitem__(self, condition) -> "FilteredPipelineView": - """Enable pandas-like filtering: view[condition]""" - df = self.df - if isinstance(condition, pd.Series): - filtered_df = df[condition] - elif callable(condition): - filtered_df = df[condition(df)] - else: - filtered_df = df[condition] - - return FilteredPipelineView(self.pipeline, filtered_df, self._hash_to_data_map) - - def query(self, expr: str) -> "FilteredPipelineView": - """SQL-like querying: view.query('operation_name == "MyOperation"')""" - df = self.df - filtered_df = df.query(expr) - return FilteredPipelineView(self.pipeline, filtered_df, self._hash_to_data_map) - - def groupby(self, *args, **kwargs) -> "GroupedPipelineView": - """Group operations similar to pandas groupby""" - df = self.df - grouped = df.groupby(*args, **kwargs) - return GroupedPipelineView(self.pipeline, grouped, self._hash_to_data_map) - - def head(self, n: int = 5) -> pd.DataFrame: - """Return first n rows""" - return self.df.head(n) - - def info(self) -> None: - """Display DataFrame info""" - return self.df.info() - - def describe(self) -> pd.DataFrame: - """Generate descriptive statistics""" - return self.df.describe() - - # Persistence methods - def save(self, path: Path | str, include_pipeline: bool = True) -> None: - """Save view, optionally with complete pipeline state""" - path = Path(path) - - # Build the view data if not cached - df = self.df - - view_data = { - "view_id": self.view_id, - "view_type": self.__class__.__name__, - "dataframe": df, - "build_options": self._build_options, - "hash_to_data_map": self._hash_to_data_map, - "created_at": time.time(), - "pipeline_hash": self._compute_pipeline_hash(), - } - - if include_pipeline: - view_data["pipeline_state"] = { - "name": self.pipeline.name, - "invocation_lut": self.pipeline.invocation_lut, - } - view_data["has_pipeline"] = True - else: - view_data["pipeline_name"] = self.pipeline.name - view_data["has_pipeline"] = False - - with open(path, "wb") as f: - pickle.dump(view_data, f, protocol=pickle.HIGHEST_PROTOCOL) - - @classmethod - def load( - cls, path: Path | str, pipeline: Pipeline | None = None - ) -> "PandasPipelineView": - """Load view, reconstructing pipeline if needed""" - with open(path, "rb") as f: - view_data = pickle.load(f) - - # Handle pipeline reconstruction - if view_data["has_pipeline"]: - pipeline = Pipeline(view_data["pipeline_state"]["name"]) - pipeline.invocation_lut = view_data["pipeline_state"]["invocation_lut"] - elif pipeline is None: - raise ValueError( - "View was saved without pipeline state. " - "You must provide a pipeline parameter." 
- ) - - # Reconstruct view - build_options = view_data.get("build_options", {}) - view = cls( - pipeline, - view_id=view_data["view_id"], - max_records=build_options.get("max_records", 10000), - sample_size=build_options.get("sample_size", 100), - ) - view._cached_data = view_data["dataframe"] - view._hash_to_data_map = view_data.get("hash_to_data_map", {}) - - return view - - def _is_cache_valid(self, cached_view: "PandasPipelineView") -> bool: - """Check if cached view is still valid""" - try: - cached_hash = getattr(cached_view, "_pipeline_hash", None) - current_hash = self._compute_pipeline_hash() - return cached_hash == current_hash - except Exception: - return False - - def invalidate(self) -> None: - """Force re-rendering on next access""" - self._cached_data = None - cache_path = self._cache_dir / f"{self.view_id}.pkl" - if cache_path.exists(): - cache_path.unlink() - - -class FilteredPipelineView: - """Represents a filtered subset of pipeline metadata""" - - def __init__( - self, pipeline: Pipeline, filtered_df: pd.DataFrame, data_map: dict[str, Any] - ): - self.pipeline = pipeline - self.df = filtered_df - self._data_map = data_map - - def __getitem__(self, condition): - """Further filtering""" - further_filtered = self.df[condition] - return FilteredPipelineView(self.pipeline, further_filtered, self._data_map) - - def query(self, expr: str): - """Apply additional query""" - further_filtered = self.df.query(expr) - return FilteredPipelineView(self.pipeline, further_filtered, self._data_map) - - def to_pandas(self) -> pd.DataFrame: - """Convert to regular pandas DataFrame""" - return self.df.copy() - - def head(self, n: int = 5) -> pd.DataFrame: - """Return first n rows""" - return self.df.head(n) - - def __len__(self) -> int: - return len(self.df) - - def __repr__(self) -> str: - return f"FilteredPipelineView({len(self.df)} records)" - - -class GroupedPipelineView: - """Represents grouped pipeline metadata""" - - def __init__(self, pipeline: Pipeline, grouped_df, data_map: dict[str, Any]): - self.pipeline = pipeline - self.grouped = grouped_df - self._data_map = data_map - - def apply(self, func): - """Apply function to each group""" - return self.grouped.apply(func) - - def agg(self, *args, **kwargs): - """Aggregate groups""" - return self.grouped.agg(*args, **kwargs) - - def size(self): - """Get group sizes""" - return self.grouped.size() - - def get_group(self, name): - """Get specific group""" - group_df = self.grouped.get_group(name) - return FilteredPipelineView(self.pipeline, group_df, self._data_map) - - -# Basic DataJoint View (simplified implementation) -class DataJointPipelineView(PipelineView): - """ - Basic DataJoint view - creates tables for pipeline operations - This is a simplified version - you can expand based on your existing DJ code - """ - - def __init__(self, pipeline: Pipeline, schema, view_id: str = "dj_view"): - super().__init__(pipeline, view_id) - self.schema = schema - self._tables = {} - - def save(self, path: Path | str, include_pipeline: bool = True) -> None: - """Save DataJoint view metadata""" - view_data = { - "view_id": self.view_id, - "view_type": self.__class__.__name__, - "schema_database": self.schema.database, - "table_names": list(self._tables.keys()), - "created_at": time.time(), - } - - if include_pipeline: - view_data["pipeline_state"] = { - "name": self.pipeline.name, - "invocation_lut": self.pipeline.invocation_lut, - } - view_data["has_pipeline"] = True - - with open(path, "wb") as f: - pickle.dump(view_data, f) - - @classmethod - 
def load( - cls, path: Path | str, pipeline: Pipeline | None = None - ) -> "DataJointPipelineView": - """Load DataJoint view""" - with open(path, "rb") as f: - view_data = pickle.load(f) - - # This would need actual DataJoint schema reconstruction - # For now, return a basic instance - if pipeline is None: - raise ValueError("Pipeline required for DataJoint view loading") - - # You'd need to reconstruct the schema here - view = cls(pipeline, None, view_id=view_data["view_id"]) # schema=None for now - return view - - def generate_tables(self): - """Generate DataJoint tables from pipeline - placeholder implementation""" - # This would use your existing DataJoint generation logic - # from your dj/tracker.py file - pass - - -# Utility functions -def validate_pipeline_serializability(pipeline: Pipeline) -> None: - """Helper to check if pipeline can be saved""" - try: - pipeline._validate_serializable() - print("✅ Pipeline is ready for serialization") - - # Additional performance warnings - stats = pipeline.get_stats() - if stats["total_invocations"] > 1000: - print( - f"⚠️ Large pipeline ({stats['total_invocations']} invocations) - views may be slow to build" - ) - - except SerializationError as e: - print("❌ Pipeline cannot be serialized:") - print(str(e)) - print("\n💡 Convert lambda functions to named functions:") - print(" lambda x: x > 0.8 → def filter_func(x): return x > 0.8") - - -def create_example_pipeline() -> Pipeline: - """Create an example pipeline for testing""" - from orcapod import GlobSource, function_pod - - @function_pod - def example_function(input_file): - return f"processed_{input_file}" - - pipeline = Pipeline("example") - - with pipeline: - # This would need actual operations to be meaningful - # source = GlobSource('data', './test_data', '*.txt')() - # results = source >> example_function - pass - - return pipeline diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py new file mode 100644 index 0000000..55207c2 --- /dev/null +++ b/src/orcapod/pipeline/wrappers.py @@ -0,0 +1,667 @@ +from orcapod.core.pod import Pod, FunctionPod +from orcapod.core import SyncStream, Source, Kernel +from orcapod.store import ArrowDataStore +from orcapod.types import Tag, Packet, TypeSpec, default_registry +from orcapod.types.typespec import extract_function_typespecs +from orcapod.hashing import ObjectHasher, ArrowHasher +from orcapod.hashing.defaults import get_default_object_hasher, get_default_arrow_hasher +from typing import Any, Literal +from collections.abc import Collection, Iterator +from orcapod.types.registry import TypeRegistry, PacketConverter +import pyarrow as pa +import polars as pl +from orcapod.core.streams import SyncStreamFromGenerator +from orcapod.utils.stream_utils import get_typespec, merge_typespecs + +import logging +logger = logging.getLogger(__name__) + +def tag_to_arrow_table_with_metadata(tag, metadata: dict | None = None): + """ + Convert a tag dictionary to PyArrow table with metadata on each column. 
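+
+    For example (illustrative values), passing tag={"subject": "s01"} together
+    with metadata={"source": "tag"} yields a one-row table whose "subject"
+    column carries the Arrow field metadata {b"source": b"tag"}.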
+
+    Args:
+        tag: Dictionary with string keys and any Python data type values
+        metadata: Optional mapping of metadata to attach to every column,
+            e.g. {"source": "tag"} to indicate that the column came from a tag
+    """
+    if metadata is None:
+        metadata = {}
+
+    # First create the table to infer types
+    temp_table = pa.Table.from_pylist([tag])
+
+    # Create new fields with metadata
+    fields_with_metadata = []
+    for field in temp_table.schema:
+        # Add metadata to each field
+        field_metadata = metadata
+        new_field = pa.field(
+            field.name, field.type, nullable=field.nullable, metadata=field_metadata
+        )
+        fields_with_metadata.append(new_field)
+
+    # Create schema with metadata
+    schema_with_metadata = pa.schema(fields_with_metadata)
+
+    # Create the final table with the metadata-enriched schema
+    table = pa.Table.from_pylist([tag], schema=schema_with_metadata)
+
+    return table
+
+def get_columns_with_metadata(df: pl.DataFrame, key: str, value: str | None = None) -> list[str]:
+    """Get column names whose metadata contains the given key. If value is given, only
+    columns whose metadata value for the designated key matches it will be returned.
+    Otherwise, all columns that contain the key as metadata will be returned regardless of the value."""
+    return [
+        col_name
+        for col_name, dtype in df.schema.items()
+        if (meta := getattr(dtype, "metadata", None)) is not None
+        and key in meta
+        and (value is None or meta[key] == value)
+    ]
+
+
+class PolarsSource(Source):
+    def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None):
+        self.df = df
+        self.tag_keys = tag_keys
+
+    def forward(self, *streams: SyncStream, **kwargs) -> SyncStream:
+        if len(streams) != 0:
+            raise ValueError(
+                "PolarsSource does not support forwarding streams. "
+                "It generates its own stream from the DataFrame."
+ ) + return PolarsStream(self.df, self.tag_keys) + + +class PolarsStream(SyncStream): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str]|None = None): + self.df = df + if tag_keys is None: + # extract tag_keys by picking columns with metadata source=tag + tag_keys = get_columns_with_metadata(df, "source", "tag") + self.tag_keys = tag_keys + + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + for row in self.df.iter_rows(named=True): + tag = {key: row[key] for key in self.tag_keys} + packet = {key: val for key, val in row.items() if key not in self.tag_keys} + yield tag, packet + +class EmptyStream(SyncStream): + def __init__(self, tag_keys: Collection[str]|None = None, packet_keys: Collection[str]|None = None, tag_typespec: TypeSpec | None = None, packet_typespec:TypeSpec|None = None): + if tag_keys is None and tag_typespec is not None: + tag_keys = tag_typespec.keys() + self.tag_keys = list(tag_keys) if tag_keys else [] + + if packet_keys is None and packet_typespec is not None: + packet_keys = packet_typespec.keys() + self.packet_keys = list(packet_keys) if packet_keys else [] + + self.tag_typespec = tag_typespec + self.packet_typespec = packet_typespec + + def keys(self, *streams: SyncStream, trigger_run: bool = False) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.tag_keys, self.packet_keys + + def types(self, *streams: SyncStream, trigger_run: bool = False) -> tuple[TypeSpec | None, TypeSpec | None]: + return self.tag_typespec, self.packet_typespec + + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + # Empty stream, no data to yield + return iter([]) + + + + +class KernelInvocationWrapper(Kernel): + def __init__(self, kernel: Kernel, input_streams: Collection[SyncStream], **kwargs) -> None: + super().__init__(**kwargs) + self.kernel = kernel + self.input_streams = list(input_streams) + + + def __repr__(self): + return f"{self.__class__.__name__}<{self.kernel!r}>" + + def __str__(self): + return f"{self.__class__.__name__}<{self.kernel}>" + + @property + def label(self) -> str: + return self._label or self.kernel.label + + @label.setter + def label(self, label: str) -> None: + self._label = label + + def resolve_input_streams(self, *input_streams) -> Collection[SyncStream]: + if input_streams: + raise ValueError( + "Wrapped pod with specified streams cannot be invoked with additional streams" + ) + return self.input_streams + + def identity_structure(self, *streams: SyncStream) -> Any: + """ + Identity structure that includes the wrapped kernel's identity structure. 
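+
+        In effect (a sketch of the intended equivalence, not additional behavior):
+        a wrapper bound to input streams s1 and s2 delegates to
+        self.kernel.identity_structure(s1, s2), so the wrapped invocation hashes
+        the same as a direct invocation of the underlying kernel.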
+ """ + resolved_streams = self.resolve_input_streams(*streams) + return self.kernel.identity_structure(*resolved_streams) + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + resolved_streams = self.resolve_input_streams(*streams) + return self.kernel.keys(*resolved_streams, trigger_run=trigger_run) + + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + resolved_streams = self.resolve_input_streams(*streams) + return self.kernel.types(*resolved_streams, trigger_run=trigger_run) + + def claims_unique_tags( + self, *streams: SyncStream, trigger_run: bool = False + ) -> bool | None: + resolved_streams = self.resolve_input_streams(*streams) + return self.kernel.claims_unique_tags( + *resolved_streams, trigger_run=trigger_run + ) + + +class CachedKernelWrapper(KernelInvocationWrapper, Source): + """ + A Kernel wrapper that wraps a kernel and stores the outputs of the kernel. + If the class is instantiated with input_streams that is not None, then this wrapper + will strictly represent the invocation of the wrapped Kernel on the given input streams. + Passing in an empty list into input_streams would still be registered as a specific invocation. + If input_streams is None, the class instance largely acts as a proxy of the underlying kernel + but will try to save all results. Note that depending on the storage type passed in, the saving + may error out if you invoke the instance on input streams with non-compatible schema (e.g., tags with + different keys). + """ + + def __init__( + self, + kernel: Kernel, + input_streams: Collection[SyncStream], + output_store: ArrowDataStore, + _object_hasher: ObjectHasher | None = None, + _arrow_hasher: ArrowHasher | None = None, + _registry: TypeRegistry | None = None, + **kwargs, + ) -> None: + super().__init__(kernel, input_streams,**kwargs) + + self.output_store = output_store + self.tag_keys, self.packet_keys = self.keys(trigger_run=False) + self.output_converter = None + + # These are configurable but are not expected to be modified except for special circumstances + if _object_hasher is None: + _object_hasher = get_default_object_hasher() + self.object_hasher = _object_hasher + if _arrow_hasher is None: + _arrow_hasher = get_default_arrow_hasher() + self.arrow_hasher = _arrow_hasher + if _registry is None: + _registry = default_registry + self.registry = _registry + self.source_info = self.label, self.object_hasher.hash_to_hex(self.kernel) + + self._cache_computed = False + + + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: + if self._cache_computed: + logger.info(f"Returning cached outputs for {self}") + if self.df is not None: + return PolarsStream(self.df, tag_keys=self.tag_keys) + else: + return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.packet_keys) + + resolved_streams = self.resolve_input_streams(*streams) + + output_stream = self.kernel.forward(*resolved_streams, **kwargs) + + tag_type, packet_type = output_stream.types(trigger_run=False) + if tag_type is not None and packet_type is not None: + joined_type = merge_typespecs(tag_type, packet_type) + assert joined_type is not None, "Joined typespec should not be None" + self.output_converter = PacketConverter(joined_type, registry=self.registry) + + # Cache the output stream of the underlying kernel + # This is a no-op if the output stream is already cached + def generator() -> Iterator[tuple[Tag, Packet]]: + 
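+            # Each (tag, packet) pair from the wrapped kernel is merged into a
+            # single row, converted to an Arrow table, and hashed into a stable
+            # record id; the row is written to output_store only if no record
+            # with that id exists yet, so duplicate rows are never stored twice.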
logger.info(f"Computing and caching outputs for {self}") + for tag, packet in output_stream: + merged_info = {**tag, **packet} + if self.output_converter is None: + joined_type = get_typespec(merged_info) + assert joined_type is not None, "Joined typespec should not be None" + self.output_converter = PacketConverter( + joined_type, registry=self.registry + ) + + output_table = self.output_converter.to_arrow_table(merged_info) + output_id = self.arrow_hasher.hash_table(output_table) + if not self.output_store.get_record(*self.source_info, output_id): + self.output_store.add_record( + *self.source_info, + output_id, + output_table, + ) + yield tag, packet + self._cache_computed = True + + return SyncStreamFromGenerator(generator) + + @property + def lazy_df(self) -> pl.LazyFrame | None: + return self.output_store.get_all_records_as_polars(*self.source_info) + + @property + def df(self) -> pl.DataFrame | None: + lazy_df = self.lazy_df + if lazy_df is None: + return None + return lazy_df.collect() + + + def reset_cache(self): + self._cache_computed = False + + + +class FunctionPodInvocationWrapper(KernelInvocationWrapper, Pod): + """ + Convenience class to wrap a function pod, providing default pass-through + implementations + """ + def __init__(self, function_pod: FunctionPod, input_streams: Collection[SyncStream], **kwargs): + + # note that this would be an alias to the self.kernel but here explicitly taken as function_pod + # for better type hints + # MRO will be KernelInvocationWrapper -> Pod -> Kernel + super().__init__(function_pod, input_streams, **kwargs) + self.function_pod = function_pod + + + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: + resolved_streams = self.resolve_input_streams(*streams) + return super().forward(*resolved_streams, **kwargs) + + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: + return self.function_pod.call(tag, packet) + + + # =============pass through methods/properties to the underlying function pod============= + + def set_active(self, active=True): + """ + Set the active state of the function pod. + """ + self.function_pod.set_active(active) + + def is_active(self) -> bool: + """ + Check if the function pod is active. 
+ """ + return self.function_pod.is_active() + + + + + + +class CachedFunctionPodWrapper(FunctionPodInvocationWrapper, Source): + def __init__( + self, + function_pod: FunctionPod, + input_streams: Collection[SyncStream], + output_store: ArrowDataStore, + tag_store: ArrowDataStore | None = None, + label: str | None = None, + skip_memoization_lookup: bool = False, + skip_memoization: bool = False, + skip_tag_record: bool = False, + error_handling: Literal["raise", "ignore", "warn"] = "raise", + _object_hasher: ObjectHasher | None = None, + _arrow_hasher: ArrowHasher | None = None, + _registry: TypeRegistry | None = None, + **kwargs, + ) -> None: + super().__init__( + function_pod, + input_streams, + label=label, + error_handling=error_handling, + **kwargs, + ) + self.output_store = output_store + self.tag_store = tag_store + + self.skip_memoization_lookup = skip_memoization_lookup + self.skip_memoization = skip_memoization + self.skip_tag_record = skip_tag_record + + # These are configurable but are not expected to be modified except for special circumstances + if _object_hasher is None: + _object_hasher = get_default_object_hasher() + self.object_hasher = _object_hasher + if _arrow_hasher is None: + _arrow_hasher = get_default_arrow_hasher() + self.arrow_hasher = _arrow_hasher + if _registry is None: + _registry = default_registry + self.registry = _registry + + # TODO: consider making this dynamic + self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) + self.tag_keys, self.output_keys = self.keys(trigger_run=False) + + + # prepare packet converters + input_typespec, output_typespec = self.function_pod.get_function_typespecs() + + self.input_converter = PacketConverter(input_typespec, self.registry) + self.output_converter = PacketConverter(output_typespec, self.registry) + + self._cache_computed = False + + def reset_cache(self): + self._cache_computed = False + + def generator_completion_hook(self, n_computed: int) -> None: + """ + Hook to be called when the generator is completed. + """ + logger.info(f"Results cached for {self}") + self._cache_computed = True + + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: + if self._cache_computed: + logger.info(f"Returning cached outputs for {self}") + if self.df is not None: + return PolarsStream(self.df, self.tag_keys) + else: + return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.output_keys) + logger.info(f"Computing and caching outputs for {self}") + return super().forward(*streams, **kwargs) + + + def get_packet_key(self, packet: Packet) -> str: + # TODO: reconsider the logic around input/output converter -- who should own this? + return self.arrow_hasher.hash_table( + self.input_converter.to_arrow_table(packet) + ) + + @property + def source_info(self): + return self.function_pod.function_name, self.function_pod_hash + + def is_memoized(self, packet: Packet) -> bool: + return self.retrieve_memoized(packet) is not None + + def add_tag_record(self, tag: Tag, packet: Packet) -> Tag: + """ + Record the tag for the packet in the record store. + This is used to keep track of the tags associated with memoized packets. 
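+
+        For example (illustrative values): recording the tag {"subject": "s01"}
+        stores and returns {"subject": "s01", "__packet_key": "<packet hash>"},
+        where the packet key links the tag to its memoized output packet.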
+ """ + return self._add_tag_record_with_packet_key(tag, self.get_packet_key(packet)) + + def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: + if self.tag_store is None: + raise ValueError("Recording of tag requires tag_store but none provided") + + tag = dict(tag) # ensure we don't modify the original tag + tag["__packet_key"] = packet_key + + # TODO: consider making this more efficient + # convert tag to arrow table - columns are labeled with metadata source=tag + table = tag_to_arrow_table_with_metadata(tag, {"source": "tag"}) + + entry_hash = self.arrow_hasher.hash_table(table) + + # TODO: add error handling + # check if record already exists: + retrieved_table = self.tag_store.get_record(*self.source_info, entry_hash) + if retrieved_table is None: + self.tag_store.add_record(*self.source_info, entry_hash, table) + + return tag + + def retrieve_memoized(self, packet: Packet) -> Packet | None: + """ + Retrieve a memoized packet from the data store. + Returns None if no memoized packet is found. + """ + logger.debug("Retrieving memoized packet") + return self._retrieve_memoized_with_packet_key(self.get_packet_key(packet)) + + def _retrieve_memoized_with_packet_key(self, packet_key: str) -> Packet | None: + """ + Retrieve a memoized result packet from the data store, looking up by the packet key + Returns None if no memoized packet is found. + """ + logger.debug(f"Retrieving memoized packet with key {packet_key}") + arrow_table = self.output_store.get_record( + self.function_pod.function_name, + self.function_pod_hash, + packet_key, + ) + if arrow_table is None: + return None + packets = self.function_pod.output_converter.from_arrow_table(arrow_table) + # since memoizing single packet, it should only contain one packet + assert len(packets) == 1, ( + f"Memoizing single packet return {len(packets)} packets!" + ) + return packets[0] + + def memoize( + self, + packet: Packet, + output_packet: Packet, + ) -> Packet: + """ + Memoize the output packet in the data store. + Returns the memoized packet. + """ + logger.debug("Memoizing packet") + return self._memoize_with_packet_key(self.get_packet_key(packet), output_packet) + + def _memoize_with_packet_key( + self, packet_key: str, output_packet: Packet + ) -> Packet: + """ + Memoize the output packet in the data store, looking up by packet key. + Returns the memoized packet. + """ + logger.debug(f"Memoizing packet with key {packet_key}") + # TODO: this logic goes through the entire store and retrieve cycle with two conversions + # consider simpler alternative + packets = self.output_converter.from_arrow_table( + self.output_store.add_record( + self.function_pod.function_name, + self.function_pod_hash, + packet_key, + self.output_converter.to_arrow_table(output_packet), + ) + ) + # since passed in a single packet, it should only return a single packet + assert len(packets) == 1, ( + f"Memoizing single packet returned {len(packets)} packets!" 
+ ) + return packets[0] + + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: + packet_key = "" + if ( + not self.skip_tag_record + or not self.skip_memoization_lookup + or not self.skip_memoization + ): + packet_key = self.get_packet_key(packet) + + output_packet = None + if not self.skip_memoization_lookup: + output_packet = self._retrieve_memoized_with_packet_key(packet_key) + if output_packet is not None: + logger.debug( + f"Memoized output for {packet} with {packet_key} found, skipping computation" + ) + else: + logger.debug( + f"Memoized output for packet {packet} with {packet_key} not found" + ) + + if output_packet is None: + # TODO: revisit the logic around active state and how to use it + tag, output_packet = self.function_pod.call(tag, packet) + if output_packet is not None and not self.skip_memoization: + # output packet may be modified by the memoization process + # e.g. if the output is a file, the path may be changed + output_packet = self._memoize_with_packet_key(packet_key, output_packet) # type: ignore + + if output_packet is None: + if self.is_active(): + logger.warning( + f"Function pod {self.function_pod.function_name} returned None for packet {packet} despite being active" + ) + return tag, None + + # result was successfully computed -- save the tag + if not self.skip_tag_record and self.tag_store is not None: + self._add_tag_record_with_packet_key(tag, packet_key) + + return tag, output_packet + + def get_all_outputs(self) -> pl.LazyFrame | None: + return self.output_store.get_all_records_as_polars(*self.source_info) + + def get_all_tags(self, with_packet_id: bool = False) -> pl.LazyFrame | None: + if self.tag_store is None: + raise ValueError("Tag store is not set, no tag record can be retrieved") + data = self.tag_store.get_all_records_as_polars(*self.source_info) + if not with_packet_id: + return data.drop("__packet_key") if data is not None else None + return data + + def get_all_entries_with_tags(self) -> pl.LazyFrame | None: + """ + Retrieve all entries from the tag store with their associated tags. + Returns a DataFrame with columns for tag and packet key. + """ + if self.tag_store is None: + raise ValueError("Tag store is not set, no tag record can be retrieved") + + tag_records = self.tag_store.get_all_records_as_polars(*self.source_info) + if tag_records is None: + return None + result_packets = self.output_store.get_records_by_ids_as_polars( + *self.source_info, + tag_records.collect()["__packet_key"], + preserve_input_order=True, + ) + if result_packets is None: + return None + + return pl.concat([tag_records, result_packets], how="horizontal").drop( + ["__packet_key"] + ) + + @property + def df(self) -> pl.DataFrame | None: + lazy_df = self.lazy_df + if lazy_df is None: + return None + return lazy_df.collect() + + @property + def lazy_df(self) -> pl.LazyFrame | None: + return self.get_all_entries_with_tags() + + @property + def tags(self) -> pl.DataFrame | None: + data = self.get_all_tags() + if data is None: + return None + + return data.collect() + + @property + def outputs(self) -> pl.DataFrame | None: + """ + Retrieve all outputs from the result store as a DataFrame. + Returns None if no outputs are available. 
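+
+        See also: tags for the recorded tag records and df for tags joined with
+        their corresponding result packets.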
+        """
+        data = self.get_all_outputs()
+        if data is None:
+            return None
+
+        return data.collect()
+
+
+class DummyFunctionPod(Pod):
+    def __init__(self, function_name="dummy", **kwargs):
+        super().__init__(**kwargs)
+        self.function_name = function_name
+
+    def set_active(self, active: bool = True):
+        # no-op
+        pass
+
+    def is_active(self) -> bool:
+        return False
+
+    def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]:
+        raise NotImplementedError(
+            "DummyFunctionPod cannot be called, it is only used to access previously stored tags and outputs."
+        )
+
+
+# TODO: Create this instead using compositional pattern
+class DummyCachedFunctionPod(CachedFunctionPodWrapper):
+    """
+    Stand-in for a cached function pod. It allows the user to access previously
+    stored function pod tags and outputs without having to instantiate the
+    identical function used for computation.
+
+    Consequently, this function pod CANNOT be used to compute and insert new entries into the storage.
+    """
+
+    def __init__(self, source_pod: CachedFunctionPodWrapper):
+        self._source_info = source_pod.source_info
+        self.output_store = source_pod.output_store
+        self.tag_store = source_pod.tag_store
+        self.function_pod = DummyFunctionPod(source_pod.function_pod.function_name)
+
+    @property
+    def source_info(self) -> tuple[str, str]:
+        return self._source_info
+
+
+class Node(KernelInvocationWrapper, Source):
+    def __init__(self, kernel: Kernel, input_nodes: Collection["Node"], **kwargs):
+        """
+        Create a node that wraps a kernel and provides a Node interface.
+        This is useful for creating nodes in a pipeline that can be executed.
+        """
+        super().__init__(kernel, input_nodes, **kwargs)
+
+    def reset_cache(self) -> None: ...
+
+
+class KernelNode(Node, CachedKernelWrapper):
+    """
+    A node that wraps a Kernel and provides a Node interface.
+    This is useful for creating nodes in a pipeline that can be executed.
+    """
+
+
+class FunctionPodNode(Node, CachedFunctionPodWrapper):
+    """
+    A node that wraps a FunctionPod and provides a Node interface.
+    This is useful for creating nodes in a pipeline that can be executed.
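+
+    These nodes are what Pipeline.compile creates for FunctionPod invocations.
+    A sketch of typical access (names are placeholders):
+
+        node = pipeline.my_pod    # resolved through Pipeline.__getattr__
+        node.df                   # tags joined with cached output packets
+        node.outputs              # all cached output packets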
+ """ \ No newline at end of file diff --git a/src/orcapod/pod/__init__.py b/src/orcapod/pod/__init__.py deleted file mode 100644 index 8567c2a..0000000 --- a/src/orcapod/pod/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .core import Pod, FunctionPod, function_pod, TypedFunctionPod, typed_function_pod - -__all__ = [ - "Pod", - "FunctionPod", - "function_pod", - "TypedFunctionPod", - "typed_function_pod", -] diff --git a/src/orcapod/pod/core.py b/src/orcapod/pod/core.py deleted file mode 100644 index 069e0de..0000000 --- a/src/orcapod/pod/core.py +++ /dev/null @@ -1,877 +0,0 @@ -import functools -import logging -import pickle -import warnings -from abc import abstractmethod -import pyarrow as pa -import sys -from collections.abc import Callable, Collection, Iterable, Iterator, Sequence -from typing import ( - Any, - Literal, -) - -from orcapod.types.registry import PacketConverter - -from orcapod.core.base import Kernel -from orcapod.hashing import ( - ObjectHasher, - ArrowHasher, - FunctionInfoExtractor, - get_function_signature, - hash_function, - get_default_object_hasher, - get_default_arrow_hasher, -) -from orcapod.core.operators import Join -from orcapod.store import DataStore, ArrowDataStore, NoOpDataStore -from orcapod.core.streams import SyncStream, SyncStreamFromGenerator -from orcapod.types import Packet, PathSet, PodFunction, Tag, TypeSpec - -from orcapod.types.default import default_registry -from orcapod.types.inference import ( - extract_function_data_types, - verify_against_typespec, - check_typespec_compatibility, -) -from orcapod.types.registry import is_packet_supported -import polars as pl - -logger = logging.getLogger(__name__) - - -def function_pod( - output_keys: Collection[str] | None = None, - function_name: str | None = None, - data_store: DataStore | None = None, - store_name: str | None = None, - function_hash_mode: Literal["signature", "content", "name", "custom"] = "name", - custom_hash: int | None = None, - force_computation: bool = False, - skip_memoization: bool = False, - error_handling: Literal["raise", "ignore", "warn"] = "raise", - **kwargs, -) -> Callable[..., "FunctionPod"]: - """ - Decorator that wraps a function in a FunctionPod instance. 
- - Args: - output_keys: Keys for the function output - force_computation: Whether to force computation - skip_memoization: Whether to skip memoization - - Returns: - FunctionPod instance wrapping the decorated function - """ - - def decorator(func) -> FunctionPod: - if func.__name__ == "": - raise ValueError("Lambda functions cannot be used with function_pod") - - if not hasattr(func, "__module__") or func.__module__ is None: - raise ValueError( - f"Function {func.__name__} must be defined at module level" - ) - - # Store the original function in the module for pickling purposes - # and make sure to change the name of the function - module = sys.modules[func.__module__] - base_function_name = func.__name__ - new_function_name = f"_original_{func.__name__}" - setattr(module, new_function_name, func) - # rename the function to be consistent and make it pickleable - setattr(func, "__name__", new_function_name) - setattr(func, "__qualname__", new_function_name) - - # Create the FunctionPod - pod = FunctionPod( - function=func, - output_keys=output_keys, - function_name=function_name or base_function_name, - data_store=data_store, - store_name=store_name, - function_hash_mode=function_hash_mode, - custom_hash=custom_hash, - force_computation=force_computation, - skip_memoization=skip_memoization, - error_handling=error_handling, - **kwargs, - ) - - return pod - - return decorator - - -class Pod(Kernel): - """ - An (abstract) base class for all pods. A pod can be seen as a special type of operation that - only operates on the packet content without reading tags. Consequently, no operation - of Pod can dependent on the tags of the packets. This is a design choice to ensure that - the pods act as pure functions which is a necessary condition to guarantee reproducibility. - """ - - def __init__( - self, error_handling: Literal["raise", "ignore", "warn"] = "raise", **kwargs - ): - super().__init__(**kwargs) - self.error_handling = error_handling - - def process_stream(self, *streams: SyncStream) -> list[SyncStream]: - """ - Prepare the incoming streams for execution in the pod. This default implementation - joins all the streams together and raises and error if no streams are provided. - """ - # if multiple streams are provided, join them - # otherwise, return as is - combined_streams = list(streams) - if len(streams) > 1: - stream = streams[0] - for next_stream in streams[1:]: - stream = Join()(stream, next_stream) - combined_streams = [stream] - return combined_streams - - def __call__(self, *streams: SyncStream, **kwargs) -> SyncStream: - stream = self.process_stream(*streams) - return super().__call__(*stream, **kwargs) - - def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet]: ... 
- - def forward(self, *streams: SyncStream) -> SyncStream: - # if multiple streams are provided, join them - if len(streams) > 1: - raise ValueError("Multiple streams should be joined before calling forward") - if len(streams) == 0: - raise ValueError("No streams provided to forward") - stream = streams[0] - - def generator() -> Iterator[tuple[Tag, Packet]]: - n_computed = 0 - for tag, packet in stream: - try: - tag, output_packet = self.call(tag, packet) - n_computed += 1 - logger.info(f"Computed item {n_computed}") - yield tag, output_packet - - except Exception as e: - logger.error(f"Error processing packet {packet}: {e}") - if self.error_handling == "raise": - raise e - elif self.error_handling == "ignore": - continue - elif self.error_handling == "warn": - warnings.warn(f"Error processing packet {packet}: {e}") - continue - - return SyncStreamFromGenerator(generator) - - -class FunctionPod(Pod): - """ - A pod that wraps a function and allows it to be used as an operation in a stream. - This pod can be used to apply a function to the packets in a stream, with optional memoization - and caching of results. It can also handle multiple output keys and error handling. - The function should accept keyword arguments that correspond to the keys in the packets. - The output of the function should be a path or a collection of paths that correspond to the output keys.""" - - def __init__( - self, - function: PodFunction, - output_keys: Collection[str] | None = None, - function_name=None, - data_store: DataStore | None = None, - store_name: str | None = None, - function_hash_mode: Literal["signature", "content", "name", "custom"] = "name", - custom_hash: int | None = None, - label: str | None = None, - skip_computation: bool = False, - force_computation: bool = False, - skip_memoization_lookup: bool = False, - skip_memoization: bool = False, - error_handling: Literal["raise", "ignore", "warn"] = "raise", - _hash_function_kwargs: dict | None = None, - **kwargs, - ) -> None: - super().__init__(label=label, **kwargs) - self.function = function - self.output_keys = output_keys or [] - if function_name is None: - if hasattr(self.function, "__name__"): - function_name = getattr(self.function, "__name__") - else: - raise ValueError( - "function_name must be provided if function has no __name__ attribute" - ) - - self.function_name = function_name - self.data_store = data_store if data_store is not None else NoOpDataStore() - self.store_name = store_name or function_name - self.function_hash_mode = function_hash_mode - self.custom_hash = custom_hash - self.skip_computation = skip_computation - self.force_computation = force_computation - self.skip_memoization_lookup = skip_memoization_lookup - self.skip_memoization = skip_memoization - self.error_handling = error_handling - self._hash_function_kwargs = _hash_function_kwargs - - def __repr__(self) -> str: - func_sig = get_function_signature(self.function) - return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" - - def keys( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[Collection[str] | None, Collection[str] | None]: - stream = self.process_stream(*streams) - tag_keys, _ = stream[0].keys(trigger_run=trigger_run) - return tag_keys, tuple(self.output_keys) - - def is_memoized(self, packet: Packet) -> bool: - return self.retrieve_memoized(packet) is not None - - def retrieve_memoized(self, packet: Packet) -> Packet | None: - """ - Retrieve a memoized packet from the data store. - Returns None if no memoized packet is found. 
- """ - return self.data_store.retrieve_memoized( - self.store_name, - self.content_hash(char_count=16), - packet, - ) - - def memoize( - self, - packet: Packet, - output_packet: Packet, - ) -> Packet: - """ - Memoize the output packet in the data store. - Returns the memoized packet. - """ - return self.data_store.memoize( - self.store_name, - self.content_hash(char_count=16), # identity of this function pod - packet, - output_packet, - ) - - def forward(self, *streams: SyncStream) -> SyncStream: - # if multiple streams are provided, join them - if len(streams) > 1: - raise ValueError("Multiple streams should be joined before calling forward") - if len(streams) == 0: - raise ValueError("No streams provided to forward") - stream = streams[0] - - def generator() -> Iterator[tuple[Tag, Packet]]: - n_computed = 0 - for tag, packet in stream: - output_values: list["PathSet"] = [] - try: - if not self.skip_memoization_lookup: - memoized_packet = self.retrieve_memoized(packet) - else: - memoized_packet = None - if not self.force_computation and memoized_packet is not None: - logger.info("Memoized packet found, skipping computation") - yield tag, memoized_packet - continue - if self.skip_computation: - logger.info("Skipping computation as per configuration") - continue - values = self.function(**packet) - - if len(self.output_keys) == 0: - output_values = [] - elif len(self.output_keys) == 1: - output_values = [values] # type: ignore - elif isinstance(values, Iterable): - output_values = list(values) # type: ignore - elif len(self.output_keys) > 1: - raise ValueError( - "Values returned by function must be a pathlike or a sequence of pathlikes" - ) - - if len(output_values) != len(self.output_keys): - raise ValueError( - f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" - ) - except Exception as e: - logger.error(f"Error processing packet {packet}: {e}") - if self.error_handling == "raise": - raise e - elif self.error_handling == "ignore": - continue - elif self.error_handling == "warn": - warnings.warn(f"Error processing packet {packet}: {e}") - continue - else: - raise ValueError( - f"Unknown error handling mode: {self.error_handling} encountered while handling error:" - ) from e - - output_packet: Packet = { - k: v for k, v in zip(self.output_keys, output_values) - } - - if not self.skip_memoization: - # output packet may be modified by the memoization process - # e.g. 
if the output is a file, the path may be changed - output_packet = self.memoize(packet, output_packet) # type: ignore - - n_computed += 1 - logger.info(f"Computed item {n_computed}") - yield tag, output_packet - - return SyncStreamFromGenerator(generator) - - def identity_structure(self, *streams) -> Any: - content_kwargs = self._hash_function_kwargs - if self.function_hash_mode == "content": - if content_kwargs is None: - content_kwargs = { - "include_name": False, - "include_module": False, - "include_declaration": False, - } - function_hash_value = hash_function( - self.function, - name_override=self.function_name, - function_hash_mode="content", - content_kwargs=content_kwargs, - ) - elif self.function_hash_mode == "signature": - function_hash_value = hash_function( - self.function, - name_override=self.function_name, - function_hash_mode="signature", - content_kwargs=content_kwargs, - ) - elif self.function_hash_mode == "name": - function_hash_value = hash_function( - self.function, - name_override=self.function_name, - function_hash_mode="name", - content_kwargs=content_kwargs, - ) - elif self.function_hash_mode == "custom": - if self.custom_hash is None: - raise ValueError("Custom hash function not provided") - function_hash_value = self.custom_hash - else: - raise ValueError( - f"Unknown function hash mode: {self.function_hash_mode}. " - "Must be one of 'content', 'signature', 'name', or 'custom'." - ) - - return ( - self.__class__.__name__, - function_hash_value, - tuple(self.output_keys), - ) + tuple(streams) - - -def typed_function_pod( - output_keys: str | Collection[str] | None = None, - function_name: str | None = None, - label: str | None = None, - result_store: ArrowDataStore | None = None, - tag_store: ArrowDataStore | None = None, - object_hasher: ObjectHasher | None = None, - arrow_hasher: ArrowHasher | None = None, - **kwargs, -) -> Callable[..., "TypedFunctionPod | CachedFunctionPod"]: - """ - Decorator that wraps a function in a FunctionPod instance. - - Args: - output_keys: Keys for the function output(s) - function_name: Name of the function pod; if None, defaults to the function name - **kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details. 
- - Returns: - FunctionPod instance wrapping the decorated function - """ - - def decorator(func) -> TypedFunctionPod | CachedFunctionPod: - if func.__name__ == "": - raise ValueError("Lambda functions cannot be used with function_pod") - - if not hasattr(func, "__module__") or func.__module__ is None: - raise ValueError( - f"Function {func.__name__} must be defined at module level" - ) - - # Store the original function in the module for pickling purposes - # and make sure to change the name of the function - module = sys.modules[func.__module__] - base_function_name = func.__name__ - new_function_name = f"_original_{func.__name__}" - setattr(module, new_function_name, func) - # rename the function to be consistent and make it pickleable - setattr(func, "__name__", new_function_name) - setattr(func, "__qualname__", new_function_name) - - # Create a simple typed function pod - pod = TypedFunctionPod( - function=func, - output_keys=output_keys, - function_name=function_name or base_function_name, - label=label, - **kwargs, - ) - - if result_store is not None: - pod = CachedFunctionPod( - function_pod=pod, - object_hasher=object_hasher - if object_hasher is not None - else get_default_object_hasher(), - arrow_hasher=arrow_hasher - if arrow_hasher is not None - else get_default_arrow_hasher(), - result_store=result_store, - tag_store=tag_store, - ) - - return pod - - return decorator - - -class TypedFunctionPod(Pod): - """ - A type-aware pod that wraps a function and provides automatic type validation and inference. - - This pod extends the base Pod functionality by automatically extracting and validating - type information from function signatures and user-provided specifications. It ensures - type safety by verifying that both input and output types are supported by the - configured type registry before execution. - - The TypedFunctionPod analyzes the wrapped function's signature to determine: - - Parameter types (from annotations or user-provided input_types) - - Return value types (from annotations or user-provided output_types) - - Type compatibility with the packet type registry - - Key Features: - - Automatic type extraction from function annotations - - Type override support via input_types and output_types parameters - - Registry-based type validation ensuring data compatibility - - Memoization support with type-aware caching - - Multiple output key handling with proper type mapping - - Comprehensive error handling for type mismatches - - Type Resolution Priority: - 1. User-provided input_types/output_types override function annotations - 2. Function parameter annotations are used when available - 3. Function return annotations are parsed for output type inference - 4. Error raised if types cannot be determined or are unsupported - - Args: - function: The function to wrap. Must accept keyword arguments corresponding - to packet keys and return values compatible with output_keys. - output_keys: Collection of string keys for the function outputs. For functions - returning a single value, provide a single key. For multiple returns - (tuple/list), provide keys matching the number of return items. - function_name: Optional name for the function. Defaults to function.__name__. - input_types: Optional mapping of parameter names to their types. Overrides - function annotations for specified parameters. - output_types: Optional type specification for return values. 
Can be: - - A dict mapping output keys to types (TypeSpec) - - A sequence of types mapped to output_keys in order - These override inferred types from function return annotations. - data_store: DataStore instance for memoization. Defaults to NoOpDataStore. - function_hasher: Hasher function for creating function identity hashes. - Required parameter - no default implementation available. - label: Optional label for the pod instance. - skip_memoization_lookup: If True, skips checking for memoized results. - skip_memoization: If True, disables memoization entirely. - error_handling: How to handle execution errors: - - "raise": Raise exceptions (default) - - "ignore": Skip failed packets silently - - "warn": Issue warnings and continue - packet_type_registry: Registry for validating packet types. Defaults to - the default registry if None. - **kwargs: Additional arguments passed to the parent Pod class and above. - - Raises: - ValueError: When: - - function_name cannot be determined and is not provided - - Input types are not supported by the registry - - Output types are not supported by the registry - - Type extraction fails due to missing annotations/specifications - NotImplementedError: When function_hasher is None (required parameter). - - Examples: - Basic usage with annotated function: - - >>> def process_data(text: str, count: int) -> tuple[str, int]: - ... return text.upper(), count * 2 - >>> - >>> pod = TypedFunctionPod( - ... function=process_data, - ... output_keys=['upper_text', 'doubled_count'], - ... function_hasher=my_hasher - ... ) - - Override types for legacy function: - - >>> def legacy_func(x, y): # No annotations - ... return x + y - >>> - >>> pod = TypedFunctionPod( - ... function=legacy_func, - ... output_keys=['sum'], - ... input_types={'x': int, 'y': int}, - ... output_types={'sum': int}, - ... function_hasher=my_hasher - ... ) - - Multiple outputs with sequence override: - - >>> def analyze(data: list) -> tuple[int, float, str]: - ... return len(data), sum(data), str(data) - >>> - >>> pod = TypedFunctionPod( - ... function=analyze, - ... output_keys=['count', 'total', 'repr'], - ... output_types=[int, float, str], # Override with sequence - ... function_hasher=my_hasher - ... ) - - Attributes: - function: The wrapped function. - output_keys: List of output key names. - function_name: Name identifier for the function. - function_input_types: Resolved input type specification. - function_output_types: Resolved output type specification. - registry: Type registry for validation. - data_store: DataStore instance for memoization. - function_hasher: Function hasher for identity computation. - skip_memoization_lookup: Whether to skip memoization lookups. - skip_memoization: Whether to disable memoization entirely. - error_handling: Error handling strategy. - - Note: - The TypedFunctionPod requires a function_hasher to be provided as there - is no default implementation. This hasher is used to create stable - identity hashes for memoization and caching purposes. - - Type validation occurs during initialization, ensuring that any type - incompatibilities are caught early rather than during stream processing. 
- """ - - def __init__( - self, - function: Callable[..., Any], - output_keys: str | Collection[str] | None = None, - function_name=None, - input_types: TypeSpec | None = None, - output_types: TypeSpec | Sequence[type] | None = None, - label: str | None = None, - packet_type_registry=None, - function_info_extractor: FunctionInfoExtractor | None = None, - **kwargs, - ) -> None: - super().__init__(label=label, **kwargs) - self.function = function - if output_keys is None: - output_keys = [] - if isinstance(output_keys, str): - output_keys = [output_keys] - self.output_keys = output_keys - if function_name is None: - if hasattr(self.function, "__name__"): - function_name = getattr(self.function, "__name__") - else: - raise ValueError( - "function_name must be provided if function has no __name__ attribute" - ) - self.function_name = function_name - - if packet_type_registry is None: - packet_type_registry = default_registry - - self.registry = packet_type_registry - self.function_info_extractor = function_info_extractor - - # extract input and output types from the function signature - function_input_types, function_output_types = extract_function_data_types( - self.function, - self.output_keys, - input_types=input_types, - output_types=output_types, - ) - - self.function_input_types = function_input_types - self.function_output_types = function_output_types - - # TODO: include explicit check of support during PacketConverter creation - self.input_converter = PacketConverter(self.function_input_types, self.registry) - self.output_converter = PacketConverter( - self.function_output_types, self.registry - ) - - # TODO: prepare a separate str and repr methods - def __repr__(self) -> str: - func_sig = get_function_signature(self.function) - return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" - - def call(self, tag, packet) -> tuple[Tag, Packet]: - output_values: list["PathSet"] = [] - - values = self.function(**packet) - - if len(self.output_keys) == 0: - output_values = [] - elif len(self.output_keys) == 1: - output_values = [values] # type: ignore - elif isinstance(values, Iterable): - output_values = list(values) # type: ignore - elif len(self.output_keys) > 1: - raise ValueError( - "Values returned by function must be a pathlike or a sequence of pathlikes" - ) - - if len(output_values) != len(self.output_keys): - raise ValueError( - f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" - ) - - output_packet: Packet = {k: v for k, v in zip(self.output_keys, output_values)} - return tag, output_packet - - def identity_structure(self, *streams) -> Any: - # construct identity structure for the function - # if function_info_extractor is available, use that but substitute the function_name - if self.function_info_extractor is not None: - function_info = self.function_info_extractor.extract_function_info( - self.function, - function_name=self.function_name, - input_types=self.function_input_types, - output_types=self.function_output_types, - ) - else: - # use basic information only - function_info = { - "name": self.function_name, - "input_types": self.function_input_types, - "output_types": self.function_output_types, - } - function_info["output_keys"] = tuple(self.output_keys) - - return ( - self.__class__.__name__, - function_info, - ) + tuple(streams) - - def keys( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[Collection[str] | None, Collection[str] | None]: - stream = 
self.process_stream(*streams) - tag_keys, _ = stream[0].keys(trigger_run=trigger_run) - return tag_keys, tuple(self.output_keys) - - -class CachedFunctionPod(Pod): - def __init__( - self, - function_pod: TypedFunctionPod, - object_hasher: ObjectHasher, - arrow_hasher: ArrowHasher, - result_store: ArrowDataStore, - tag_store: ArrowDataStore | None = None, - label: str | None = None, - skip_memoization_lookup: bool = False, - skip_memoization: bool = False, - skip_tag_record: bool = False, - error_handling: Literal["raise", "ignore", "warn"] = "raise", - **kwargs, - ) -> None: - super().__init__(label=label, error_handling=error_handling, **kwargs) - self.function_pod = function_pod - - self.object_hasher = object_hasher - self.arrow_hasher = arrow_hasher - self.result_store = result_store - self.tag_store = tag_store - - self.skip_memoization_lookup = skip_memoization_lookup - self.skip_memoization = skip_memoization - self.skip_tag_record = skip_tag_record - - # TODO: consider making this dynamic - self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) - - def get_packet_key(self, packet: Packet) -> str: - return self.arrow_hasher.hash_table( - self.function_pod.input_converter.to_arrow_table(packet) - ) - - # TODO: prepare a separate str and repr methods - def __repr__(self) -> str: - return f"Cached:{self.function_pod}" - - def keys( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.function_pod.keys(*streams, trigger_run=trigger_run) - - def is_memoized(self, packet: Packet) -> bool: - return self.retrieve_memoized(packet) is not None - - def add_tag_record(self, tag: Tag, packet: Packet) -> Tag: - """ - Record the tag for the packet in the record store. - This is used to keep track of the tags associated with memoized packets. - """ - - return self._add_tag_record_with_packet_key(tag, self.get_packet_key(packet)) - - def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: - if self.tag_store is None: - raise ValueError("Recording of tag requires tag_store but none provided") - - tag = dict(tag) # ensure we don't modify the original tag - tag["__packet_key"] = packet_key - - # convert tag to arrow table - table = pa.Table.from_pylist([tag]) - - entry_hash = self.arrow_hasher.hash_table(table) - - # TODO: add error handling - # check if record already exists: - retrieved_table = self.tag_store.get_record( - self.function_pod.function_name, self.function_pod_hash, entry_hash - ) - if retrieved_table is None: - self.tag_store.add_record( - self.function_pod.function_name, - self.function_pod_hash, - entry_hash, - table, - ) - - return tag - - def retrieve_memoized(self, packet: Packet) -> Packet | None: - """ - Retrieve a memoized packet from the data store. - Returns None if no memoized packet is found. - """ - logger.info("Retrieving memoized packet") - return self._retrieve_memoized_by_hash(self.get_packet_key(packet)) - - def _retrieve_memoized_by_hash(self, packet_hash: str) -> Packet | None: - """ - Retrieve a memoized result packet from the data store, looking up by hash - Returns None if no memoized packet is found. 
- """ - logger.info(f"Retrieving memoized packet with hash {packet_hash}") - arrow_table = self.result_store.get_record( - self.function_pod.function_name, - self.function_pod_hash, - packet_hash, - ) - if arrow_table is None: - return None - packets = self.function_pod.output_converter.from_arrow_table(arrow_table) - # since memoizing single packet, it should only contain one packet - assert len(packets) == 1, ( - f"Memoizing single packet return {len(packets)} packets!" - ) - return packets[0] - - def memoize( - self, - packet: Packet, - output_packet: Packet, - ) -> Packet: - """ - Memoize the output packet in the data store. - Returns the memoized packet. - """ - logger.info("Memoizing packet") - return self._memoize_by_hash(self.get_packet_key(packet), output_packet) - - def _memoize_by_hash(self, packet_hash: str, output_packet: Packet) -> Packet: - """ - Memoize the output packet in the data store, looking up by hash. - Returns the memoized packet. - """ - logger.info(f"Memoizing packet with hash {packet_hash}") - packets = self.function_pod.output_converter.from_arrow_table( - self.result_store.add_record( - self.function_pod.function_name, - self.function_pod_hash, - packet_hash, - self.function_pod.output_converter.to_arrow_table(output_packet), - ) - ) - # since memoizing single packet, it should only contain one packet - assert len(packets) == 1, ( - f"Memoizing single packet return {len(packets)} packets!" - ) - return packets[0] - - def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet]: - packet_key = "" - if ( - not self.skip_tag_record - or not self.skip_memoization_lookup - or not self.skip_memoization - ): - packet_key = self.get_packet_key(packet) - - if not self.skip_tag_record and self.tag_store is not None: - self._add_tag_record_with_packet_key(tag, packet_key) - - if not self.skip_memoization_lookup: - memoized_packet = self._retrieve_memoized_by_hash(packet_key) - else: - memoized_packet = None - if memoized_packet is not None: - logger.info("Memoized packet found, skipping computation") - return tag, memoized_packet - - tag, output_packet = self.function_pod.call(tag, packet) - - if not self.skip_memoization: - # output packet may be modified by the memoization process - # e.g. if the output is a file, the path may be changed - output_packet = self.memoize(packet, output_packet) # type: ignore - - return tag, output_packet - - def get_all_entries_with_tags(self) -> pl.LazyFrame | None: - """ - Retrieve all entries from the tag store with their associated tags. - Returns a DataFrame with columns for tag and packet key. 
- """ - if self.tag_store is None: - raise ValueError("Tag store is not set, cannot retrieve entries") - - tag_records = self.tag_store.get_all_records_as_polars( - self.function_pod.function_name, self.function_pod_hash - ) - if tag_records is None: - return None - result_packets = self.result_store.get_records_by_ids_as_polars( - self.function_pod.function_name, - self.function_pod_hash, - tag_records.collect()["__packet_key"], - preserve_input_order=True, - ) - if result_packets is None: - return None - - return pl.concat([tag_records, result_packets], how="horizontal").drop( - ["__packet_key"] - ) - - def identity_structure(self, *streams) -> Any: - return self.function_pod.identity_structure(*streams) diff --git a/src/orcapod/store/__init__.py b/src/orcapod/store/__init__.py index f573c4d..281874b 100644 --- a/src/orcapod/store/__init__.py +++ b/src/orcapod/store/__init__.py @@ -1,5 +1,6 @@ from .types import DataStore, ArrowDataStore -from .core import DirDataStore, NoOpDataStore +from .arrow_data_stores import MockArrowDataStore, SimpleInMemoryDataStore +from .dict_data_stores import DirDataStore, NoOpDataStore from .safe_dir_data_store import SafeDirDataStore __all__ = [ @@ -8,4 +9,6 @@ "DirDataStore", "SafeDirDataStore", "NoOpDataStore", + "MockArrowDataStore", + "SimpleInMemoryDataStore", ] diff --git a/src/orcapod/store/arrow_data_stores.py b/src/orcapod/store/arrow_data_stores.py index 4be9698..e2c1376 100644 --- a/src/orcapod/store/arrow_data_stores.py +++ b/src/orcapod/store/arrow_data_stores.py @@ -23,17 +23,15 @@ class MockArrowDataStore: def __init__(self): logger.info("Initialized MockArrowDataStore") - def add_record(self, - source_name: str, - source_id: str, - entry_id: str, - arrow_data: pa.Table) -> pa.Table: + def add_record( + self, source_name: str, source_id: str, entry_id: str, arrow_data: pa.Table + ) -> pa.Table: """Add a record to the mock store.""" return arrow_data - def get_record(self, source_name: str, - source_id: str, - entry_id: str) -> pa.Table | None: + def get_record( + self, source_name: str, source_id: str, entry_id: str + ) -> pa.Table | None: """Get a specific record.""" return None @@ -76,7 +74,7 @@ def get_records_by_ids( Arrow table containing all found records, or None if no records found """ return None - + def get_records_by_ids_as_polars( self, source_name: str, @@ -88,21 +86,19 @@ def get_records_by_ids_as_polars( return None - - -class InMemoryArrowDataStore: +class SimpleInMemoryDataStore: """ - In-memory Arrow data store for testing purposes. + In-memory Arrow data store, primarily to be used for testing purposes. This class simulates the behavior of ParquetArrowDataStore without actual file I/O. It is useful for unit tests where you want to avoid filesystem dependencies. - + Uses dict of dict of Arrow tables for efficient storage and retrieval. """ def __init__(self, duplicate_entry_behavior: str = "error"): """ Initialize the InMemoryArrowDataStore. 
- + Args: duplicate_entry_behavior: How to handle duplicate entry_ids: - 'error': Raise ValueError when entry_id already exists @@ -112,10 +108,12 @@ def __init__(self, duplicate_entry_behavior: str = "error"): if duplicate_entry_behavior not in ["error", "overwrite"]: raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") self.duplicate_entry_behavior = duplicate_entry_behavior - + # Store Arrow tables: {source_key: {entry_id: arrow_table}} self._in_memory_store: dict[str, dict[str, pa.Table]] = {} - logger.info(f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'") + logger.info( + f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'" + ) def _get_source_key(self, source_name: str, source_id: str) -> str: """Generate key for source storage.""" @@ -127,40 +125,42 @@ def add_record( source_id: str, entry_id: str, arrow_data: pa.Table, + ignore_duplicate: bool = False, ) -> pa.Table: """ Add a record to the in-memory store. - + Args: source_name: Name of the data source source_id: ID of the specific dataset within the source entry_id: Unique identifier for this record arrow_data: The Arrow table data to store - + Returns: - The original arrow_data table - + arrow_data equivalent to having loaded the corresponding entry that was just saved + Raises: ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' """ source_key = self._get_source_key(source_name, source_id) - + # Initialize source if it doesn't exist if source_key not in self._in_memory_store: self._in_memory_store[source_key] = {} - + local_data = self._in_memory_store[source_key] - + # Check for duplicate entry - if entry_id in local_data and self.duplicate_entry_behavior == "error": - raise ValueError( - f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - + if entry_id in local_data: + if not ignore_duplicate and self.duplicate_entry_behavior == "error": + raise ValueError( + f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " + f"Use duplicate_entry_behavior='overwrite' to allow updates." 
+ ) + # Store the record local_data[entry_id] = arrow_data - + action = "Updated" if entry_id in local_data else "Added" logger.debug(f"{action} record {entry_id} in {source_key}") return arrow_data @@ -173,24 +173,29 @@ def get_record( local_data = self._in_memory_store.get(source_key, {}) return local_data.get(entry_id) - def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + def get_all_records( + self, source_name: str, source_id: str, add_entry_id_column: bool | str = False + ) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" source_key = self._get_source_key(source_name, source_id) local_data = self._in_memory_store.get(source_key, {}) - + if not local_data: return None tables_with_keys = [] for key, table in local_data.items(): # Add entry_id column to each table - key_array = pa.array([key] * len(table), type=pa.string()) + key_array = pa.array([key] * len(table), type=pa.large_string()) table_with_key = table.add_column(0, "__entry_id", key_array) tables_with_keys.append(table_with_key) # Concatenate all tables if tables_with_keys: - return pa.concat_tables(tables_with_keys) + combined_table = pa.concat_tables(tables_with_keys) + if not add_entry_id_column: + combined_table = combined_table.drop(columns=["__entry_id"]) + return combined_table return None def get_all_records_as_polars( @@ -212,7 +217,7 @@ def get_records_by_ids( ) -> pa.Table | None: """ Retrieve records by entry IDs as a single table. - + Args: source_name: Name of the data source source_id: ID of the specific dataset within the source @@ -226,7 +231,7 @@ def get_records_by_ids( - str: Include entry ID column with custom name preserve_input_order: If True, return results in the same order as input entry_ids, with null rows for missing entries. If False, return in storage order. 
- + Returns: Arrow table containing all found records, or None if no records found """ @@ -250,18 +255,18 @@ def get_records_by_ids( source_key = self._get_source_key(source_name, source_id) local_data = self._in_memory_store.get(source_key, {}) - + if not local_data: return None # Collect matching tables found_tables = [] found_entry_ids = [] - + if preserve_input_order: # Preserve input order, include nulls for missing entries first_table_schema = None - + for entry_id in entry_ids_list: if entry_id in local_data: table = local_data[entry_id] @@ -270,7 +275,7 @@ def get_records_by_ids( table_with_key = table.add_column(0, "__entry_id", key_array) found_tables.append(table_with_key) found_entry_ids.append(entry_id) - + # Store schema for creating null rows if first_table_schema is None: first_table_schema = table_with_key.schema @@ -281,12 +286,14 @@ def get_records_by_ids( null_data = {} for field in first_table_schema: if field.name == "__entry_id": - null_data[field.name] = pa.array([entry_id], type=field.type) + null_data[field.name] = pa.array( + [entry_id], type=field.type + ) else: # Create null array with proper type null_array = pa.array([None], type=field.type) null_data[field.name] = null_array - + null_table = pa.table(null_data, schema=first_table_schema) found_tables.append(null_table) found_entry_ids.append(entry_id) @@ -315,12 +322,17 @@ def get_records_by_ids( # Remove the __entry_id column column_names = combined_table.column_names if "__entry_id" in column_names: - indices_to_keep = [i for i, name in enumerate(column_names) if name != "__entry_id"] + indices_to_keep = [ + i for i, name in enumerate(column_names) if name != "__entry_id" + ] combined_table = combined_table.select(indices_to_keep) elif isinstance(add_entry_id_column, str): # Rename __entry_id to custom name schema = combined_table.schema - new_names = [add_entry_id_column if name == "__entry_id" else name for name in schema.names] + new_names = [ + add_entry_id_column if name == "__entry_id" else name + for name in schema.names + ] combined_table = combined_table.rename_columns(new_names) # If add_entry_id_column is True, keep __entry_id as is @@ -336,7 +348,7 @@ def get_records_by_ids_as_polars( ) -> pl.LazyFrame | None: """ Retrieve records by entry IDs as a single Polars LazyFrame. - + Args: source_name: Name of the data source source_id: ID of the specific dataset within the source @@ -350,7 +362,7 @@ def get_records_by_ids_as_polars( - str: Include entry ID column with custom name preserve_input_order: If True, return results in the same order as input entry_ids, with null rows for missing entries. If False, return in storage order. - + Returns: Polars LazyFrame containing all found records, or None if no records found """ @@ -358,42 +370,42 @@ def get_records_by_ids_as_polars( arrow_result = self.get_records_by_ids( source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order ) - + if arrow_result is None: return None - + # Convert to Polars LazyFrame return pl.LazyFrame(arrow_result) def save_to_parquet(self, base_path: str | Path) -> None: """ Save all data to Parquet files in a directory structure. 
- + Directory structure: base_path/source_name/source_id/data.parquet - + Args: base_path: Base directory path where to save the Parquet files """ base_path = Path(base_path) base_path.mkdir(parents=True, exist_ok=True) - + saved_count = 0 - + for source_key, local_data in self._in_memory_store.items(): if not local_data: continue - + # Parse source_name and source_id from the key if ":" not in source_key: logger.warning(f"Invalid source key format: {source_key}, skipping") continue - + source_name, source_id = source_key.split(":", 1) - + # Create directory structure source_dir = base_path / source_name / source_id source_dir.mkdir(parents=True, exist_ok=True) - + # Combine all tables for this source with entry_id column tables_with_keys = [] for entry_id, table in local_data.items(): @@ -401,82 +413,89 @@ def save_to_parquet(self, base_path: str | Path) -> None: key_array = pa.array([entry_id] * len(table), type=pa.string()) table_with_key = table.add_column(0, "__entry_id", key_array) tables_with_keys.append(table_with_key) - + # Concatenate all tables if tables_with_keys: combined_table = pa.concat_tables(tables_with_keys) - + # Save as Parquet file + # TODO: perform safe "atomic" write parquet_path = source_dir / "data.parquet" import pyarrow.parquet as pq + pq.write_table(combined_table, parquet_path) - + saved_count += 1 - logger.debug(f"Saved {len(combined_table)} records for {source_key} to {parquet_path}") - + logger.debug( + f"Saved {len(combined_table)} records for {source_key} to {parquet_path}" + ) + logger.info(f"Saved {saved_count} sources to Parquet files in {base_path}") def load_from_parquet(self, base_path: str | Path) -> None: """ Load data from Parquet files with the expected directory structure. - + Expected structure: base_path/source_name/source_id/data.parquet - + Args: base_path: Base directory path containing the Parquet files """ base_path = Path(base_path) - + if not base_path.exists(): logger.warning(f"Base path {base_path} does not exist") return - + # Clear existing data self._in_memory_store.clear() - + loaded_count = 0 - + # Traverse directory structure: source_name/source_id/ for source_name_dir in base_path.iterdir(): if not source_name_dir.is_dir(): continue - + source_name = source_name_dir.name - + for source_id_dir in source_name_dir.iterdir(): if not source_id_dir.is_dir(): continue - + source_id = source_id_dir.name source_key = self._get_source_key(source_name, source_id) - + # Look for Parquet files in this directory parquet_files = list(source_id_dir.glob("*.parquet")) - + if not parquet_files: logger.debug(f"No Parquet files found in {source_id_dir}") continue - + # Load all Parquet files and combine them all_records = [] - + for parquet_file in parquet_files: try: import pyarrow.parquet as pq + table = pq.read_table(parquet_file) - + # Validate that __entry_id column exists if "__entry_id" not in table.column_names: - logger.warning(f"Parquet file {parquet_file} missing __entry_id column, skipping") + logger.warning( + f"Parquet file {parquet_file} missing __entry_id column, skipping" + ) continue - + all_records.append(table) logger.debug(f"Loaded {len(table)} records from {parquet_file}") - + except Exception as e: logger.error(f"Failed to load Parquet file {parquet_file}: {e}") continue - + # Process all records for this source if all_records: # Combine all tables @@ -484,44 +503,53 @@ def load_from_parquet(self, base_path: str | Path) -> None: combined_table = all_records[0] else: combined_table = pa.concat_tables(all_records) - + # Split 
back into individual records by entry_id local_data = {} entry_ids = combined_table.column("__entry_id").to_pylist() - + # Group records by entry_id entry_id_groups = {} for i, entry_id in enumerate(entry_ids): if entry_id not in entry_id_groups: entry_id_groups[entry_id] = [] entry_id_groups[entry_id].append(i) - + # Extract each entry_id's records for entry_id, indices in entry_id_groups.items(): # Take rows for this entry_id and remove __entry_id column entry_table = combined_table.take(indices) - + # Remove __entry_id column column_names = entry_table.column_names if "__entry_id" in column_names: - indices_to_keep = [i for i, name in enumerate(column_names) if name != "__entry_id"] + indices_to_keep = [ + i + for i, name in enumerate(column_names) + if name != "__entry_id" + ] entry_table = entry_table.select(indices_to_keep) - + local_data[entry_id] = entry_table - + self._in_memory_store[source_key] = local_data loaded_count += 1 - + record_count = len(combined_table) unique_entries = len(entry_id_groups) - logger.debug(f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}") - + logger.debug( + f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}" + ) + logger.info(f"Loaded {loaded_count} sources from Parquet files in {base_path}") - + # Log summary of loaded data - total_records = sum(len(local_data) for local_data in self._in_memory_store.values()) + total_records = sum( + len(local_data) for local_data in self._in_memory_store.values() + ) logger.info(f"Total records loaded: {total_records}") + @dataclass class RecordMetadata: """Metadata for a stored record.""" @@ -1634,7 +1662,7 @@ def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: store.add_record( "experiments", "dataset_A", valid_entries[0], overwrite_data ) - print(f"✓ Overwrote existing record") + print("✓ Overwrote existing record") # Verify overwrite updated_record = store.get_record( @@ -1648,7 +1676,7 @@ def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: # Sync and show final stats store.force_sync() stats = store.get_stats() - print(f"\n=== Final Statistics ===") + print("\n=== Final Statistics ===") print(f"Total records: {stats['total_records']}") print(f"Loaded caches: {stats['loaded_source_caches']}") print(f"Dirty caches: {stats['dirty_caches']}") @@ -1659,6 +1687,377 @@ def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: print("\n✓ Single-row constraint testing completed successfully!") +class InMemoryPolarsDataStore: + """ + In-memory Arrow data store using Polars DataFrames for efficient storage and retrieval. + This class provides the same interface as InMemoryArrowDataStore but uses Polars internally + for better performance with large datasets and complex queries. + + Uses dict of Polars DataFrames for efficient storage and retrieval. + Each DataFrame contains all records for a source with an __entry_id column. + """ + + def __init__(self, duplicate_entry_behavior: str = "error"): + """ + Initialize the InMemoryPolarsDataStore. 
+
+        Args:
+            duplicate_entry_behavior: How to handle duplicate entry_ids:
+                - 'error': Raise ValueError when entry_id already exists
+                - 'overwrite': Replace existing entry with new data
+        """
+        # Validate duplicate behavior
+        if duplicate_entry_behavior not in ["error", "overwrite"]:
+            raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'")
+        self.duplicate_entry_behavior = duplicate_entry_behavior
+
+        # Store Polars DataFrames: {source_key: polars_dataframe}
+        # Each DataFrame has an __entry_id column plus user data columns
+        self._in_memory_store: dict[str, pl.DataFrame] = {}
+        logger.info(
+            f"Initialized InMemoryPolarsDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'"
+        )
+
+    def _get_source_key(self, source_name: str, source_id: str) -> str:
+        """Generate key for source storage."""
+        return f"{source_name}:{source_id}"
+
+    def add_record(
+        self,
+        source_name: str,
+        source_id: str,
+        entry_id: str,
+        arrow_data: pa.Table,
+    ) -> pa.Table:
+        """
+        Add a record to the in-memory store.
+
+        Args:
+            source_name: Name of the data source
+            source_id: ID of the specific dataset within the source
+            entry_id: Unique identifier for this record
+            arrow_data: The Arrow table data to store
+
+        Returns:
+            arrow_data equivalent to having loaded the corresponding entry that was just saved
+
+        Raises:
+            ValueError: If entry_id already exists and duplicate_entry_behavior is 'error'
+        """
+        source_key = self._get_source_key(source_name, source_id)
+
+        # Convert Arrow table to Polars DataFrame and add entry_id column
+        polars_data = cast(pl.DataFrame, pl.from_arrow(arrow_data))
+
+        # Add __entry_id column
+        polars_data = polars_data.with_columns(pl.lit(entry_id).alias("__entry_id"))
+
+        # Check if source exists
+        if source_key not in self._in_memory_store:
+            # First record for this source
+            self._in_memory_store[source_key] = polars_data
+            logger.debug(f"Created new source {source_key} with entry {entry_id}")
+        else:
+            existing_df = self._in_memory_store[source_key]
+
+            # Check for duplicate entry
+            entry_exists = (
+                existing_df.filter(pl.col("__entry_id") == entry_id).shape[0] > 0
+            )
+
+            if entry_exists:
+                if self.duplicate_entry_behavior == "error":
+                    raise ValueError(
+                        f"Entry '{entry_id}' already exists in {source_name}/{source_id}. "
+                        f"Use duplicate_entry_behavior='overwrite' to allow updates."
+                    )
+                else:  # validity of value is checked in constructor so it must be "overwrite"
+                    # Remove existing entry and add new one
+                    existing_df = existing_df.filter(pl.col("__entry_id") != entry_id)
+                    self._in_memory_store[source_key] = pl.concat(
+                        [existing_df, polars_data]
+                    )
+                    logger.debug(f"Overwrote entry {entry_id} in {source_key}")
+            else:
+                # Append new entry
+                try:
+                    self._in_memory_store[source_key] = pl.concat(
+                        [existing_df, polars_data]
+                    )
+                    logger.debug(f"Added entry {entry_id} to {source_key}")
+                except Exception as e:
+                    # Handle schema mismatch
+                    existing_cols = set(existing_df.columns) - {"__entry_id"}
+                    new_cols = set(polars_data.columns) - {"__entry_id"}
+
+                    if existing_cols != new_cols:
+                        raise ValueError(
+                            f"Schema mismatch for {source_key}. "
+                            f"Existing columns: {sorted(existing_cols)}, "
+                            f"New columns: {sorted(new_cols)}"
+                        ) from e
+                    else:
+                        raise e
+
+        return arrow_data
+
+    def get_record(
+        self, source_name: str, source_id: str, entry_id: str
+    ) -> pa.Table | None:
+        """Get a specific record."""
+        source_key = self._get_source_key(source_name, source_id)
+
+        if source_key not in self._in_memory_store:
+            return None
+
+        df = self._in_memory_store[source_key]
+
+        # Filter for the specific entry_id
+        filtered_df = df.filter(pl.col("__entry_id") == entry_id)
+
+        if filtered_df.shape[0] == 0:
+            return None
+
+        # Remove __entry_id column and convert to Arrow
+        result_df = filtered_df.drop("__entry_id")
+        return result_df.to_arrow()
+
+    def get_all_records(
+        self, source_name: str, source_id: str, add_entry_id_column: bool | str = False
+    ) -> pa.Table | None:
+        """Retrieve all records for a given source as a single table."""
+        df = self.get_all_records_as_polars(
+            source_name, source_id, add_entry_id_column=add_entry_id_column
+        )
+        return df.collect().to_arrow() if df is not None else None
+
+    def get_all_records_as_polars(
+        self, source_name: str, source_id: str, add_entry_id_column: bool | str = False
+    ) -> pl.LazyFrame | None:
+        """Retrieve all records for a given source as a single Polars LazyFrame."""
+        source_key = self._get_source_key(source_name, source_id)
+
+        if source_key not in self._in_memory_store:
+            return None
+
+        df = self._in_memory_store[source_key]
+
+        if df.shape[0] == 0:
+            return None
+
+        # perform column selection lazily
+        df = df.lazy()
+
+        # Handle entry_id column based on parameter
+        if add_entry_id_column is False:
+            # Remove __entry_id column
+            result_df = df.drop("__entry_id")
+        elif add_entry_id_column is True:
+            # Keep __entry_id column as is
+            result_df = df
+        elif isinstance(add_entry_id_column, str):
+            # Rename __entry_id to custom name
+            result_df = df.rename({"__entry_id": add_entry_id_column})
+        else:
+            raise ValueError(
+                f"add_entry_id_column must be a bool or str but {add_entry_id_column} was given"
+            )
+
+        return result_df
+
+    def get_records_by_ids(
+        self,
+        source_name: str,
+        source_id: str,
+        entry_ids: list[str] | pl.Series | pa.Array,
+        add_entry_id_column: bool | str = False,
+        preserve_input_order: bool = False,
+    ) -> pa.Table | None:
+        """
+        Retrieve records by entry IDs as a single table.
+
+        Args:
+            source_name: Name of the data source
+            source_id: ID of the specific dataset within the source
+            entry_ids: Entry IDs to retrieve. Can be:
+                - list[str]: List of entry ID strings
+                - pl.Series: Polars Series containing entry IDs
+                - pa.Array: PyArrow Array containing entry IDs
+            add_entry_id_column: Control entry ID column inclusion:
+                - False: Don't include entry ID column (default)
+                - True: Include entry ID column as "__entry_id"
+                - str: Include entry ID column with custom name
+            preserve_input_order: If True, return results in the same order as input entry_ids,
+                with null rows for missing entries. If False, return in storage order. 
+ + Returns: + Arrow table containing all found records, or None if no records found + """ + # Convert input to Polars Series + if isinstance(entry_ids, list): + if not entry_ids: + return None + entry_ids_series = pl.Series("entry_id", entry_ids) + elif isinstance(entry_ids, pl.Series): + if len(entry_ids) == 0: + return None + entry_ids_series = entry_ids + elif isinstance(entry_ids, pa.Array): + if len(entry_ids) == 0: + return None + entry_ids_series = pl.from_arrow(pa.table({"entry_id": entry_ids}))[ + "entry_id" + ] + else: + raise TypeError( + f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" + ) + + source_key = self._get_source_key(source_name, source_id) + + if source_key not in self._in_memory_store: + return None + + df = self._in_memory_store[source_key] + + if preserve_input_order: + # Create DataFrame with input order and join to preserve order with nulls + ordered_df = pl.DataFrame({"__entry_id": entry_ids_series}) + result_df = ordered_df.join(df, on="__entry_id", how="left") + else: + # Filter for matching entry_ids (storage order) + result_df = df.filter(pl.col("__entry_id").is_in(entry_ids_series)) + + if result_df.shape[0] == 0: + return None + + # Handle entry_id column based on parameter + if add_entry_id_column is False: + # Remove __entry_id column + result_df = result_df.drop("__entry_id") + elif add_entry_id_column is True: + # Keep __entry_id column as is + pass + elif isinstance(add_entry_id_column, str): + # Rename __entry_id to custom name + result_df = result_df.rename({"__entry_id": add_entry_id_column}) + + return result_df.to_arrow() + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """ + Retrieve records by entry IDs as a single Polars LazyFrame. + + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. 
+ + Returns: + Polars LazyFrame containing all found records, or None if no records found + """ + # Get Arrow result and convert to Polars LazyFrame + arrow_result = self.get_records_by_ids( + source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order + ) + + if arrow_result is None: + return None + + # Convert to Polars LazyFrame + return pl.from_arrow(arrow_result).lazy() + + def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: + """Check if a specific entry exists.""" + source_key = self._get_source_key(source_name, source_id) + + if source_key not in self._in_memory_store: + return False + + df = self._in_memory_store[source_key] + return df.filter(pl.col("__entry_id") == entry_id).shape[0] > 0 + + def list_entries(self, source_name: str, source_id: str) -> set[str]: + """List all entry IDs for a specific source.""" + source_key = self._get_source_key(source_name, source_id) + + if source_key not in self._in_memory_store: + return set() + + df = self._in_memory_store[source_key] + return set(df["__entry_id"].to_list()) + + def list_sources(self) -> set[tuple[str, str]]: + """List all (source_name, source_id) combinations.""" + sources = set() + for source_key in self._in_memory_store.keys(): + if ":" in source_key: + source_name, source_id = source_key.split(":", 1) + sources.add((source_name, source_id)) + return sources + + def clear_source(self, source_name: str, source_id: str) -> None: + """Clear all records for a specific source.""" + source_key = self._get_source_key(source_name, source_id) + if source_key in self._in_memory_store: + del self._in_memory_store[source_key] + logger.debug(f"Cleared source {source_key}") + + def clear_all(self) -> None: + """Clear all records from the store.""" + self._in_memory_store.clear() + logger.info("Cleared all records from store") + + def get_stats(self) -> dict[str, Any]: + """Get comprehensive statistics about the data store.""" + total_records = 0 + total_memory_mb = 0 + source_stats = [] + + for source_key, df in self._in_memory_store.items(): + record_count = df.shape[0] + total_records += record_count + + # Estimate memory usage (rough approximation) + memory_bytes = df.estimated_size() + memory_mb = memory_bytes / (1024 * 1024) + total_memory_mb += memory_mb + + source_stats.append( + { + "source_key": source_key, + "record_count": record_count, + "column_count": df.shape[1] - 1, # Exclude __entry_id + "memory_mb": round(memory_mb, 2), + "columns": [col for col in df.columns if col != "__entry_id"], + } + ) + + return { + "total_records": total_records, + "total_sources": len(self._in_memory_store), + "total_memory_mb": round(total_memory_mb, 2), + "duplicate_entry_behavior": self.duplicate_entry_behavior, + "source_details": source_stats, + } + + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) demo_single_row_constraint() diff --git a/src/orcapod/store/core.py b/src/orcapod/store/dict_data_stores.py similarity index 100% rename from src/orcapod/store/core.py rename to src/orcapod/store/dict_data_stores.py diff --git a/src/orcapod/store/file.py b/src/orcapod/store/file.py deleted file mode 100644 index 0de8aff..0000000 --- a/src/orcapod/store/file.py +++ /dev/null @@ -1,159 +0,0 @@ -import builtins -import contextlib -import inspect -import os -from pathlib import Path -from typing import Callable, Collection, Dict, Optional, Tuple, Union - -from orcapod.types import Packet, PathSet - - -@contextlib.contextmanager -def redirect_open( - mapping: Union[Dict[str, str], 
Callable[[str], Optional[str]]], -): - """ - Context manager to intercept file opening operations. - - Args: - mapping: Either a dictionary mapping original paths to their replacements, - or a function that takes a path string and returns a replacement path - (or None to indicate the file should not be opened). - - Raises: - FileNotFoundError: If using a dictionary and the path is not found in it. - """ - # Track all places that might store an open() function - places_to_patch = [] - - # 1. Standard builtins.open - original_builtin_open = builtins.open - places_to_patch.append((builtins, "open", original_builtin_open)) - - # 2. __builtins__ (could be different in some contexts, especially IPython) - if isinstance(__builtins__, dict) and "open" in __builtins__: - places_to_patch.append((__builtins__, "open", __builtins__["open"])) - - # 3. Current module's globals (for the calling namespace) - current_frame = inspect.currentframe() - if current_frame is not None: - caller_globals = current_frame.f_back.f_globals if current_frame.f_back else {} - if "open" in caller_globals: - places_to_patch.append((caller_globals, "open", caller_globals["open"])) - - # 4. Check for IPython user namespace - try: - import IPython - - ip = IPython.get_ipython() # type: ignore - if ip and "open" in ip.user_ns: - places_to_patch.append((ip.user_ns, "open", ip.user_ns["open"])) - except (ImportError, AttributeError): - pass - - def patched_open(file, *args, **kwargs): - # Convert PathLike objects to string if needed - if hasattr(file, "__fspath__"): - file_path = os.fspath(file) - else: - file_path = str(file) - - if isinstance(mapping, dict): - if file_path in mapping: - redirected_path = mapping[file_path] - print(f"Redirecting '{file_path}' to '{redirected_path}'") - return original_builtin_open(redirected_path, *args, **kwargs) - else: - raise FileNotFoundError( - f"Path '{file_path}' not found in redirection mapping" - ) - else: # mapping is a function - redirected_path = mapping(file_path) - if redirected_path is not None: - print(f"Redirecting '{file_path}' to '{redirected_path}'") - return original_builtin_open(redirected_path, *args, **kwargs) - else: - raise FileNotFoundError(f"Path '{file_path}' could not be redirected") - - # Apply the patch to all places - for obj, attr, _ in places_to_patch: - if isinstance(obj, dict): - obj[attr] = patched_open - else: - setattr(obj, attr, patched_open) - - try: - yield - finally: - # Restore all original functions - for obj, attr, original in places_to_patch: - if isinstance(obj, dict): - obj[attr] = original - else: - setattr(obj, attr, original) - - -def virtual_mount( - packet: Packet, -) -> Tuple[Packet, Dict[str, str], Dict[str, str]]: - """ - Visit all pathset within the packet, and convert them to alternative path - representation. By default, full path is mapped to the file name. If two or - more paths have the same file name, the second one is suffixed with "_1", the - third one with "_2", etc. This is useful for creating a virtual mount point - for a set of files, where the original paths are not important, but the file - names can be used to identify the files. 
- """ - forward_lut = {} # mapping from original path to new path - reverse_lut = {} # mapping from new path to original path - new_packet = {} - - for key, value in packet.items(): - new_packet[key] = convert_pathset(value, forward_lut, reverse_lut) - - return new_packet, forward_lut, reverse_lut - - -# TODO: re-assess the structure of PathSet and consider making it recursive -def convert_pathset(pathset: PathSet, forward_lut, reverse_lut) -> PathSet: - """ - Convert a pathset to a new pathset. forward_lut and reverse_lut are updated - with the new paths. The new paths are created by replacing the original paths - with the new paths in the forward_lut. The reverse_lut is updated with the - original paths. If name already exists, a suffix is added to the new name to avoid - collisions. - """ - if isinstance(pathset, (str, bytes)): - new_name = Path(pathset).name - if new_name in reverse_lut: - # if the name already exists, add a suffix - i = 1 - while f"{new_name}_{i}" in reverse_lut: - i += 1 - new_name = f"{new_name}_{i}" - forward_lut[pathset] = new_name - reverse_lut[new_name] = pathset - return new_name - elif isinstance(pathset, Collection): - return [convert_pathset(p, forward_lut, reverse_lut) for p in pathset] # type: ignore - else: - raise ValueError( - f"Unsupported pathset type: {type(pathset)}. Expected str, bytes, or Collection." - ) - - -class WrappedPath: - def __init__(self, path, name=None): - self.path = Path(path) - if name is None: - name = self.path.name - self.name = name - - def __fspath__(self) -> Union[str, bytes]: - return self.path.__fspath__() - - def __str__(self) -> str: - return self.name - - def __repr__(self) -> str: - return f"WrappedPath({self.path}): {self.name}" diff --git a/src/orcapod/store/file_ops.py b/src/orcapod/store/file_ops.py index 0e34213..4fa6202 100644 --- a/src/orcapod/store/file_ops.py +++ b/src/orcapod/store/file_ops.py @@ -1,10 +1,15 @@ # file_ops.py - Atomic file operations module +import builtins +import contextlib +import inspect import logging import os from pathlib import Path -from orcapod.types import PathLike +from orcapod.types import PathLike, PathSet, Packet +from collections.abc import Collection, Callable + logger = logging.getLogger(__name__) @@ -276,3 +281,154 @@ def is_file_locked(file_path: PathLike) -> bool: except Exception: # Any other exception - assume not locked return False + + +@contextlib.contextmanager +def redirect_open( + mapping: dict[str, str] | Callable[[str], str | None], +): + """ + Context manager to intercept file opening operations. + + Args: + mapping: Either a dictionary mapping original paths to their replacements, + or a function that takes a path string and returns a replacement path + (or None to indicate the file should not be opened). + + Raises: + FileNotFoundError: If using a dictionary and the path is not found in it. + """ + # Track all places that might store an open() function + places_to_patch = [] + + # 1. Standard builtins.open + original_builtin_open = builtins.open + places_to_patch.append((builtins, "open", original_builtin_open)) + + # 2. __builtins__ (could be different in some contexts, especially IPython) + if isinstance(__builtins__, dict) and "open" in __builtins__: + places_to_patch.append((__builtins__, "open", __builtins__["open"])) + + # 3. 
Current module's globals (for the calling namespace) + current_frame = inspect.currentframe() + if current_frame is not None: + caller_globals = current_frame.f_back.f_globals if current_frame.f_back else {} + if "open" in caller_globals: + places_to_patch.append((caller_globals, "open", caller_globals["open"])) + + # 4. Check for IPython user namespace + try: + import IPython + + ip = IPython.get_ipython() # type: ignore + if ip and "open" in ip.user_ns: + places_to_patch.append((ip.user_ns, "open", ip.user_ns["open"])) + except (ImportError, AttributeError): + pass + + def patched_open(file, *args, **kwargs): + # Convert PathLike objects to string if needed + if hasattr(file, "__fspath__"): + file_path = os.fspath(file) + else: + file_path = str(file) + + if isinstance(mapping, dict): + if file_path in mapping: + redirected_path = mapping[file_path] + print(f"Redirecting '{file_path}' to '{redirected_path}'") + return original_builtin_open(redirected_path, *args, **kwargs) + else: + raise FileNotFoundError( + f"Path '{file_path}' not found in redirection mapping" + ) + else: # mapping is a function + redirected_path = mapping(file_path) + if redirected_path is not None: + print(f"Redirecting '{file_path}' to '{redirected_path}'") + return original_builtin_open(redirected_path, *args, **kwargs) + else: + raise FileNotFoundError(f"Path '{file_path}' could not be redirected") + + # Apply the patch to all places + for obj, attr, _ in places_to_patch: + if isinstance(obj, dict): + obj[attr] = patched_open + else: + setattr(obj, attr, patched_open) + + try: + yield + finally: + # Restore all original functions + for obj, attr, original in places_to_patch: + if isinstance(obj, dict): + obj[attr] = original + else: + setattr(obj, attr, original) + + +def virtual_mount( + packet: Packet, +) -> tuple[Packet, dict[str, str], dict[str, str]]: + """ + Visit all pathset within the packet, and convert them to alternative path + representation. By default, full path is mapped to the file name. If two or + more paths have the same file name, the second one is suffixed with "_1", the + third one with "_2", etc. This is useful for creating a virtual mount point + for a set of files, where the original paths are not important, but the file + names can be used to identify the files. + """ + forward_lut = {} # mapping from original path to new path + reverse_lut = {} # mapping from new path to original path + new_packet = {} + + for key, value in packet.items(): + new_packet[key] = convert_pathset(value, forward_lut, reverse_lut) + + return new_packet, forward_lut, reverse_lut + + +# TODO: re-assess the structure of PathSet and consider making it recursive +def convert_pathset(pathset: PathSet, forward_lut, reverse_lut) -> PathSet: + """ + Convert a pathset to a new pathset. forward_lut and reverse_lut are updated + with the new paths. The new paths are created by replacing the original paths + with the new paths in the forward_lut. The reverse_lut is updated with the + original paths. If name already exists, a suffix is added to the new name to avoid + collisions. 
+ """ + if isinstance(pathset, (str, bytes)): + new_name = Path(pathset).name + if new_name in reverse_lut: + # if the name already exists, add a suffix + i = 1 + while f"{new_name}_{i}" in reverse_lut: + i += 1 + new_name = f"{new_name}_{i}" + forward_lut[pathset] = new_name + reverse_lut[new_name] = pathset + return new_name + elif isinstance(pathset, Collection): + return [convert_pathset(p, forward_lut, reverse_lut) for p in pathset] # type: ignore + else: + raise ValueError( + f"Unsupported pathset type: {type(pathset)}. Expected str, bytes, or Collection." + ) + + +class WrappedPath: + def __init__(self, path, name=None): + self.path = Path(path) + if name is None: + name = self.path.name + self.name = name + + def __fspath__(self) -> str | bytes: + return self.path.__fspath__() + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"WrappedPath({self.path}): {self.name}" diff --git a/src/orcapod/store/optimized_memory_store.py b/src/orcapod/store/optimized_memory_store.py new file mode 100644 index 0000000..ff962e9 --- /dev/null +++ b/src/orcapod/store/optimized_memory_store.py @@ -0,0 +1,433 @@ +import polars as pl +import pyarrow as pa +import logging +from typing import Any, Dict, List, Tuple, cast +from collections import defaultdict + +# Module-level logger +logger = logging.getLogger(__name__) + + +class ArrowBatchedPolarsDataStore: + """ + Arrow-batched Polars data store that minimizes Arrow<->Polars conversions. + + Key optimizations: + 1. Keep data in Arrow format during batching + 2. Only convert to Polars when consolidating or querying + 3. Batch Arrow tables and concatenate before conversion + 4. Maintain Arrow-based indexing for fast lookups + 5. Lazy Polars conversion only when needed + """ + + def __init__(self, duplicate_entry_behavior: str = "error", batch_size: int = 100): + """ + Initialize the ArrowBatchedPolarsDataStore. 
+ + Args: + duplicate_entry_behavior: How to handle duplicate entry_ids: + - 'error': Raise ValueError when entry_id already exists + - 'overwrite': Replace existing entry with new data + batch_size: Number of records to batch before consolidating + """ + if duplicate_entry_behavior not in ["error", "overwrite"]: + raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") + + self.duplicate_entry_behavior = duplicate_entry_behavior + self.batch_size = batch_size + + # Arrow batch buffer: {source_key: [(entry_id, arrow_table), ...]} + self._arrow_batches: Dict[str, List[Tuple[str, pa.Table]]] = defaultdict(list) + + # Consolidated Polars store: {source_key: polars_dataframe} + self._polars_store: Dict[str, pl.DataFrame] = {} + + # Entry ID index for fast lookups: {source_key: set[entry_ids]} + self._entry_index: Dict[str, set] = defaultdict(set) + + # Schema cache + self._schema_cache: Dict[str, pa.Schema] = {} + + logger.info( + f"Initialized ArrowBatchedPolarsDataStore with " + f"duplicate_entry_behavior='{duplicate_entry_behavior}', batch_size={batch_size}" + ) + + def _get_source_key(self, source_name: str, source_id: str) -> str: + """Generate key for source storage.""" + return f"{source_name}:{source_id}" + + def _add_entry_id_to_arrow_table(self, table: pa.Table, entry_id: str) -> pa.Table: + """Add entry_id column to Arrow table efficiently.""" + # Create entry_id array with the same length as the table + entry_id_array = pa.array([entry_id] * len(table), type=pa.string()) + + # Add column at the beginning for consistent ordering + return table.add_column(0, "__entry_id", entry_id_array) + + def _consolidate_arrow_batch(self, source_key: str) -> None: + """Consolidate Arrow batch into Polars DataFrame.""" + if source_key not in self._arrow_batches or not self._arrow_batches[source_key]: + return + + logger.debug(f"Consolidating {len(self._arrow_batches[source_key])} Arrow tables for {source_key}") + + # Prepare all Arrow tables with entry_id columns + arrow_tables_with_id = [] + + for entry_id, arrow_table in self._arrow_batches[source_key]: + table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) + arrow_tables_with_id.append(table_with_id) + + # Concatenate all Arrow tables at once (very fast) + if len(arrow_tables_with_id) == 1: + consolidated_arrow = arrow_tables_with_id[0] + else: + consolidated_arrow = pa.concat_tables(arrow_tables_with_id) + + # Single conversion to Polars + new_polars_df = cast(pl.DataFrame, pl.from_arrow(consolidated_arrow)) + + # Combine with existing Polars DataFrame if it exists + if source_key in self._polars_store: + existing_df = self._polars_store[source_key] + self._polars_store[source_key] = pl.concat([existing_df, new_polars_df]) + else: + self._polars_store[source_key] = new_polars_df + + # Clear the Arrow batch + self._arrow_batches[source_key].clear() + + logger.debug(f"Consolidated to Polars DataFrame with {len(self._polars_store[source_key])} total rows") + + def _force_consolidation(self, source_key: str) -> None: + """Force consolidation of Arrow batches.""" + if source_key in self._arrow_batches and self._arrow_batches[source_key]: + self._consolidate_arrow_batch(source_key) + + def _get_consolidated_dataframe(self, source_key: str) -> pl.DataFrame | None: + """Get consolidated Polars DataFrame, forcing consolidation if needed.""" + self._force_consolidation(source_key) + return self._polars_store.get(source_key) + + def add_record( + self, + source_name: str, + source_id: str, + entry_id: str, + 
arrow_data: pa.Table, + ) -> pa.Table: + """ + Add a record to the store using Arrow batching. + + This is the fastest path - no conversions, just Arrow table storage. + """ + source_key = self._get_source_key(source_name, source_id) + + # Check for duplicate entry + if entry_id in self._entry_index[source_key]: + if self.duplicate_entry_behavior == "error": + raise ValueError( + f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " + f"Use duplicate_entry_behavior='overwrite' to allow updates." + ) + else: + # Handle overwrite: remove from both Arrow batch and Polars store + # Remove from Arrow batch + self._arrow_batches[source_key] = [ + (eid, table) for eid, table in self._arrow_batches[source_key] + if eid != entry_id + ] + + # Remove from Polars store if it exists + if source_key in self._polars_store: + self._polars_store[source_key] = self._polars_store[source_key].filter( + pl.col("__entry_id") != entry_id + ) + + # Schema validation (cached) + if source_key in self._schema_cache: + if not self._schema_cache[source_key].equals(arrow_data.schema): + raise ValueError( + f"Schema mismatch for {source_key}. " + f"Expected: {self._schema_cache[source_key]}, " + f"Got: {arrow_data.schema}" + ) + else: + self._schema_cache[source_key] = arrow_data.schema + + # Add to Arrow batch (no conversion yet!) + self._arrow_batches[source_key].append((entry_id, arrow_data)) + self._entry_index[source_key].add(entry_id) + + # Consolidate if batch is full + if len(self._arrow_batches[source_key]) >= self.batch_size: + self._consolidate_arrow_batch(source_key) + + logger.debug(f"Added entry {entry_id} to Arrow batch for {source_key}") + return arrow_data + + def get_record( + self, source_name: str, source_id: str, entry_id: str + ) -> pa.Table | None: + """Get a specific record with optimized lookup.""" + source_key = self._get_source_key(source_name, source_id) + + # Quick existence check + if entry_id not in self._entry_index[source_key]: + return None + + # Check Arrow batch first (most recent data) + for batch_entry_id, arrow_table in self._arrow_batches[source_key]: + if batch_entry_id == entry_id: + return arrow_table + + # Check consolidated Polars store + df = self._get_consolidated_dataframe(source_key) + if df is None: + return None + + # Filter and convert back to Arrow + filtered_df = df.filter(pl.col("__entry_id") == entry_id).drop("__entry_id") + + if filtered_df.height == 0: + return None + + return filtered_df.to_arrow() + + def get_all_records( + self, source_name: str, source_id: str, add_entry_id_column: bool | str = False + ) -> pa.Table | None: + """Retrieve all records as a single Arrow table.""" + source_key = self._get_source_key(source_name, source_id) + + # Force consolidation to include all data + df = self._get_consolidated_dataframe(source_key) + if df is None or df.height == 0: + return None + + # Handle entry_id column + if add_entry_id_column is False: + result_df = df.drop("__entry_id") + elif add_entry_id_column is True: + result_df = df + elif isinstance(add_entry_id_column, str): + result_df = df.rename({"__entry_id": add_entry_id_column}) + else: + result_df = df.drop("__entry_id") + + return result_df.to_arrow() + + def get_all_records_as_polars( + self, source_name: str, source_id: str + ) -> pl.LazyFrame | None: + """Retrieve all records as a Polars LazyFrame.""" + source_key = self._get_source_key(source_name, source_id) + + df = self._get_consolidated_dataframe(source_key) + if df is None or df.height == 0: + return None + + return 
df.drop("__entry_id").lazy() + + def get_records_by_ids( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """Retrieve records by entry IDs efficiently.""" + # Convert input to list for processing + if isinstance(entry_ids, list): + if not entry_ids: + return None + entry_ids_list = entry_ids + elif isinstance(entry_ids, pl.Series): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_list() + elif isinstance(entry_ids, pa.Array): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_pylist() + else: + raise TypeError(f"entry_ids must be list[str], pl.Series, or pa.Array") + + source_key = self._get_source_key(source_name, source_id) + + # Quick filter using index + existing_entries = [ + entry_id for entry_id in entry_ids_list + if entry_id in self._entry_index[source_key] + ] + + if not existing_entries and not preserve_input_order: + return None + + # Collect from Arrow batch first + batch_tables = [] + found_in_batch = set() + + for entry_id, arrow_table in self._arrow_batches[source_key]: + if entry_id in entry_ids_list: + table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) + batch_tables.append(table_with_id) + found_in_batch.add(entry_id) + + # Get remaining from consolidated store + remaining_ids = [eid for eid in existing_entries if eid not in found_in_batch] + + consolidated_tables = [] + if remaining_ids: + df = self._get_consolidated_dataframe(source_key) + if df is not None: + if preserve_input_order: + ordered_df = pl.DataFrame({"__entry_id": entry_ids_list}) + result_df = ordered_df.join(df, on="__entry_id", how="left") + else: + result_df = df.filter(pl.col("__entry_id").is_in(remaining_ids)) + + if result_df.height > 0: + consolidated_tables.append(result_df.to_arrow()) + + # Combine all results + all_tables = batch_tables + consolidated_tables + + if not all_tables: + return None + + # Concatenate Arrow tables + if len(all_tables) == 1: + result_table = all_tables[0] + else: + result_table = pa.concat_tables(all_tables) + + # Handle entry_id column + if add_entry_id_column is False: + # Remove __entry_id column + column_names = result_table.column_names + if "__entry_id" in column_names: + indices = [i for i, name in enumerate(column_names) if name != "__entry_id"] + result_table = result_table.select(indices) + elif isinstance(add_entry_id_column, str): + # Rename __entry_id column + schema = result_table.schema + new_names = [ + add_entry_id_column if name == "__entry_id" else name + for name in schema.names + ] + result_table = result_table.rename_columns(new_names) + + return result_table + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """Retrieve records by entry IDs as Polars LazyFrame.""" + arrow_result = self.get_records_by_ids( + source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order + ) + + if arrow_result is None: + return None + + pl_result = cast(pl.DataFrame, pl.from_arrow(arrow_result)) + + return pl_result.lazy() + + def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: + """Check if entry exists using the index.""" + source_key = self._get_source_key(source_name, source_id) + return entry_id in 
self._entry_index[source_key] + + def list_entries(self, source_name: str, source_id: str) -> set[str]: + """List all entry IDs using the index.""" + source_key = self._get_source_key(source_name, source_id) + return self._entry_index[source_key].copy() + + def list_sources(self) -> set[tuple[str, str]]: + """List all source combinations.""" + sources = set() + for source_key in self._entry_index.keys(): + if ":" in source_key: + source_name, source_id = source_key.split(":", 1) + sources.add((source_name, source_id)) + return sources + + def force_consolidation(self) -> None: + """Force consolidation of all Arrow batches.""" + for source_key in list(self._arrow_batches.keys()): + self._force_consolidation(source_key) + logger.info("Forced consolidation of all Arrow batches") + + def clear_source(self, source_name: str, source_id: str) -> None: + """Clear all data for a source.""" + source_key = self._get_source_key(source_name, source_id) + + if source_key in self._arrow_batches: + del self._arrow_batches[source_key] + if source_key in self._polars_store: + del self._polars_store[source_key] + if source_key in self._entry_index: + del self._entry_index[source_key] + if source_key in self._schema_cache: + del self._schema_cache[source_key] + + logger.debug(f"Cleared source {source_key}") + + def clear_all(self) -> None: + """Clear all data.""" + self._arrow_batches.clear() + self._polars_store.clear() + self._entry_index.clear() + self._schema_cache.clear() + logger.info("Cleared all data") + + def get_stats(self) -> dict[str, Any]: + """Get comprehensive statistics.""" + total_records = sum(len(entries) for entries in self._entry_index.values()) + total_batched = sum(len(batch) for batch in self._arrow_batches.values()) + total_consolidated = sum( + len(df) for df in self._polars_store.values() + ) if self._polars_store else 0 + + source_stats = [] + for source_key in self._entry_index.keys(): + record_count = len(self._entry_index[source_key]) + batched_count = len(self._arrow_batches.get(source_key, [])) + consolidated_count = 0 + + if source_key in self._polars_store: + consolidated_count = len(self._polars_store[source_key]) + + source_stats.append({ + "source_key": source_key, + "total_records": record_count, + "batched_records": batched_count, + "consolidated_records": consolidated_count, + }) + + return { + "total_records": total_records, + "total_sources": len(self._entry_index), + "total_batched": total_batched, + "total_consolidated": total_consolidated, + "batch_size": self.batch_size, + "duplicate_entry_behavior": self.duplicate_entry_behavior, + "source_details": source_stats, + } + + def optimize_for_reads(self) -> None: + """Optimize for read operations by consolidating all batches.""" + logger.info("Optimizing for reads - consolidating all Arrow batches...") + self.force_consolidation() + # Clear Arrow batches to save memory + self._arrow_batches.clear() + logger.info("Optimization complete") \ No newline at end of file diff --git a/src/orcapod/store/transfer.py b/src/orcapod/store/transfer_data_store.py similarity index 100% rename from src/orcapod/store/transfer.py rename to src/orcapod/store/transfer_data_store.py diff --git a/src/orcapod/store/types.py b/src/orcapod/store/types.py index 6c1b5af..49d9a70 100644 --- a/src/orcapod/store/types.py +++ b/src/orcapod/store/types.py @@ -45,6 +45,7 @@ def add_record( source_id: str, entry_id: str, arrow_data: pa.Table, + ignore_duplicate: bool = False, ) -> pa.Table: ... 
def get_record( diff --git a/src/orcapod/types/__init__.py b/src/orcapod/types/__init__.py index f372259..cbcfffc 100644 --- a/src/orcapod/types/__init__.py +++ b/src/orcapod/types/__init__.py @@ -1,52 +1,29 @@ # src/orcabridge/types.py -import os -from collections.abc import Collection, Mapping -from pathlib import Path -from typing import Any, Protocol -from typing_extensions import TypeAlias -from .core import TypeSpec, TypeHandler - - -SUPPORTED_PYTHON_TYPES = (str, int, float, bool, bytes) - -# Convenience alias for anything pathlike -PathLike = str | os.PathLike - -# an (optional) string or a collection of (optional) string values -# Note that TagValue can be nested, allowing for an arbitrary depth of nested lists -TagValue: TypeAlias = str | None | Collection["TagValue"] - -# the top level tag is a mapping from string keys to values that can be a string or -# an arbitrary depth of nested list of strings or None -Tag: TypeAlias = Mapping[str, TagValue] - -# a pathset is a path or an arbitrary depth of nested list of paths -PathSet: TypeAlias = PathLike | Collection[PathLike | None] - -# Simple data types that we support (with clear Polars correspondence) -SupportedNativePythonData: TypeAlias = str | int | float | bool | bytes - -ExtendedSupportedPythonData: TypeAlias = SupportedNativePythonData | PathLike - -# Extended data values that can be stored in packets -# Either the original PathSet or one of our supported simple data types -DataValue: TypeAlias = PathSet | SupportedNativePythonData | Collection["DataValue"] - - -# a packet is a mapping from string keys to data values -Packet: TypeAlias = Mapping[str, DataValue] - -# a batch is a tuple of a tag and a list of packets -Batch: TypeAlias = tuple[Tag, Collection[Packet]] - - -class PodFunction(Protocol): - """ - A function suitable to be used in a FunctionPod. - It takes one or more named arguments, each corresponding to either: - - A path to a file or directory (PathSet) - for backward compatibility - - A simple data value (str, int, float, bool, bytes, Path) - and returns either None, a single value, or a list of values - """ - - def __call__(self, **kwargs: DataValue) -> None | DataValue | list[DataValue]: ... +from .core import Tag, Packet, TypeSpec, PathLike, PathSet, PodFunction +from .registry import TypeRegistry +from .handlers import PathHandler, UUIDHandler, DateTimeHandler +from . import handlers +from . 
import typespec + + +# Create default registry and register handlers +default_registry = TypeRegistry() + +# Register with semantic names - registry extracts supported types automatically +default_registry.register("path", PathHandler()) +default_registry.register("uuid", UUIDHandler()) +default_registry.register( + "datetime", DateTimeHandler() +) # Registers for datetime, date, time + +__all__ = [ + "default_registry", + "Tag", + "Packet", + "TypeSpec", + "PathLike", + "PathSet", + "PodFunction", + "handlers", + "typespec", +] diff --git a/src/orcapod/types/core.py b/src/orcapod/types/core.py index 5822f87..097750e 100644 --- a/src/orcapod/types/core.py +++ b/src/orcapod/types/core.py @@ -1,6 +1,8 @@ -from typing import Protocol, Any, TypeAlias, Mapping +from typing import Protocol, Any, TypeAlias import pyarrow as pa from dataclasses import dataclass +import os +from collections.abc import Collection, Mapping # TODO: reconsider the need for this dataclass as its information is superfluous @@ -20,6 +22,51 @@ class TypeInfo: ] # Mapping of parameter names to their types +SUPPORTED_PYTHON_TYPES = (str, int, float, bool, bytes) + +# Convenience alias for anything pathlike +PathLike = str | os.PathLike + +# an (optional) string or a collection of (optional) string values +# Note that TagValue can be nested, allowing for an arbitrary depth of nested lists +TagValue: TypeAlias = str | None | Collection["TagValue"] + +# the top level tag is a mapping from string keys to values that can be a string or +# an arbitrary depth of nested list of strings or None +Tag: TypeAlias = Mapping[str, TagValue] + +# a pathset is a path or an arbitrary depth of nested list of paths +PathSet: TypeAlias = PathLike | Collection[PathLike | None] + +# Simple data types that we support (with clear Polars correspondence) +SupportedNativePythonData: TypeAlias = str | int | float | bool | bytes + +ExtendedSupportedPythonData: TypeAlias = SupportedNativePythonData | PathLike + +# Extended data values that can be stored in packets +# Either the original PathSet or one of our supported simple data types +DataValue: TypeAlias = PathSet | SupportedNativePythonData | Collection["DataValue"] + + +# a packet is a mapping from string keys to data values +Packet: TypeAlias = Mapping[str, DataValue] + +# a batch is a tuple of a tag and a list of packets +Batch: TypeAlias = tuple[Tag, Collection[Packet]] + + +class PodFunction(Protocol): + """ + A function suitable to be used in a FunctionPod. + It takes one or more named arguments, each corresponding to either: + - A path to a file or directory (PathSet) - for backward compatibility + - A simple data value (str, int, float, bool, bytes, Path) + and returns either None, a single value, or a list of values + """ + + def __call__(self, **kwargs: DataValue) -> None | DataValue | list[DataValue]: ... + + class TypeHandler(Protocol): """Protocol for handling conversion between Python types and underlying Arrow data types used for storage. 
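For orientation, the aliases consolidated into orcapod.types above are intended to be used roughly as in the following minimal sketch; the word_count function, the literal tag and packet values, and the file path are hypothetical illustrations rather than part of the patch:

from pathlib import Path

from orcapod.types import Packet, Tag


def word_count(text_file: str) -> int:
    # The kind of callable a FunctionPod wraps: named arguments that are DataValues
    # (here a path-like string) in, a DataValue-compatible result out.
    return len(Path(text_file).read_text().split())


tag: Tag = {"subject": "s01", "trial": ["1", "2"]}    # str keys mapped to nested str/None values
packet: Packet = {"text_file": "data/s01/notes.txt"}  # str keys mapped to DataValue entries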
diff --git a/src/orcapod/types/default.py b/src/orcapod/types/default.py deleted file mode 100644 index d41e577..0000000 --- a/src/orcapod/types/default.py +++ /dev/null @@ -1,18 +0,0 @@ -from .registry import TypeRegistry -from .handlers import ( - PathHandler, - UUIDHandler, - SimpleMappingHandler, - DateTimeHandler, -) -import pyarrow as pa - -# Create default registry and register handlers -default_registry = TypeRegistry() - -# Register with semantic names - registry extracts supported types automatically -default_registry.register("path", PathHandler()) -default_registry.register("uuid", UUIDHandler()) -default_registry.register( - "datetime", DateTimeHandler() -) # Registers for datetime, date, time diff --git a/src/orcapod/types/registry.py b/src/orcapod/types/registry.py index 0dafda5..6b56183 100644 --- a/src/orcapod/types/registry.py +++ b/src/orcapod/types/registry.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Collection, Sequence +from collections.abc import Callable, Collection, Sequence, Mapping import logging from optparse import Values from typing import Any @@ -156,9 +156,7 @@ def _to_storage_packet(self, packet: Packet) -> dict[str, Any]: self._check_key_consistency(packet_keys) # Convert each value - storage_packet: dict[str, Any] = ( - packet.copy() - ) # Start with a copy of the packet + storage_packet: dict[str, Any] = dict(packet) # Start with a copy of the packet for key, handler in self.keys_with_handlers: try: @@ -168,7 +166,7 @@ def _to_storage_packet(self, packet: Packet) -> dict[str, Any]: return storage_packet - def _from_storage_packet(self, storage_packet: dict[str, Any]) -> Packet: + def _from_storage_packet(self, storage_packet: Mapping[str, Any]) -> Packet: """Convert storage packet back to Python packet. Args: @@ -188,7 +186,7 @@ def _from_storage_packet(self, storage_packet: dict[str, Any]) -> Packet: self._check_key_consistency(storage_keys) # Convert each value back to Python type - packet: Packet = storage_packet.copy() + packet: Packet = dict(storage_packet) for key, handler in self.keys_with_handlers: try: diff --git a/src/orcapod/types/inference.py b/src/orcapod/types/typespec.py similarity index 98% rename from src/orcapod/types/inference.py rename to src/orcapod/types/typespec.py index 2f18f39..eb5be89 100644 --- a/src/orcapod/types/inference.py +++ b/src/orcapod/types/typespec.py @@ -1,20 +1,20 @@ -# Library of functions for inferring types for FunctionPod input and output parameters. 
+# Library of functions for working with TypeSpecs and for extracting TypeSpecs from a function's signature from collections.abc import Callable, Collection, Sequence -from typing import get_origin, get_args, TypeAlias +from typing import get_origin, get_args from .core import TypeSpec import inspect import logging - logger = logging.getLogger(__name__) def verify_against_typespec(packet: dict, typespec: TypeSpec) -> bool: """Verify that the dictionary's types match the expected types in the typespec.""" from beartype.door import is_bearable + # verify that packet contains no keys not in typespec if set(packet.keys()) - set(typespec.keys()): logger.warning( @@ -40,6 +40,7 @@ def check_typespec_compatibility( incoming_types: TypeSpec, receiving_types: TypeSpec ) -> bool: from beartype.door import is_subhint + for key, type_info in incoming_types.items(): if key not in receiving_types: logger.warning(f"Key '{key}' not found in parameter types.") @@ -52,7 +53,7 @@ def check_typespec_compatibility( return True -def extract_function_data_types( +def extract_function_typespecs( func: Callable, output_keys: Collection[str], input_types: TypeSpec | None = None, diff --git a/uv.lock b/uv.lock index 589ebc2..ba522ac 100644 --- a/uv.lock +++ b/uv.lock @@ -1230,7 +1230,7 @@ requires-dist = [ { name = "matplotlib", specifier = ">=3.10.3" }, { name = "networkx" }, { name = "pandas", specifier = ">=2.2.3" }, - { name = "polars", specifier = ">=1.30.0" }, + { name = "polars", specifier = ">=1.31.0" }, { name = "pyarrow", specifier = ">=20.0.0" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "redis", marker = "extra == 'redis'", specifier = ">=6.2.0" }, @@ -1436,16 +1436,16 @@ wheels = [ [[package]] name = "polars" -version = "1.30.0" +version = "1.31.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/82/b6/8dbdf626c0705a57f052708c9fc0860ffc2aa97955930d5faaf6a66fcfd3/polars-1.30.0.tar.gz", hash = "sha256:dfe94ae84a5efd9ba74e616e3e125b24ca155494a931890a8f17480737c4db45", size = 4668318, upload-time = "2025-05-21T13:33:24.175Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/f5/de1b5ecd7d0bd0dd87aa392937f759f9cc3997c5866a9a7f94eabf37cd48/polars-1.31.0.tar.gz", hash = "sha256:59a88054a5fc0135386268ceefdbb6a6cc012d21b5b44fed4f1d3faabbdcbf32", size = 4681224, upload-time = "2025-06-18T12:00:46.24Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/48/e9b2cb379abcc9f7aff2e701098fcdb9fe6d85dc4ad4cec7b35d39c70951/polars-1.30.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:4c33bc97c29b7112f0e689a2f8a33143973a3ff466c70b25c7fd1880225de6dd", size = 35704342, upload-time = "2025-05-21T13:32:22.996Z" }, - { url = "https://files.pythonhosted.org/packages/36/ca/f545f61282f75eea4dfde4db2944963dcd59abd50c20e33a1c894da44dad/polars-1.30.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e3d05914c364b8e39a5b10dcf97e84d76e516b3b1693880bf189a93aab3ca00d", size = 32459857, upload-time = "2025-05-21T13:32:27.728Z" }, - { url = "https://files.pythonhosted.org/packages/76/20/e018cd87d7cb6f8684355f31f4e193222455a6e8f7b942f4a2934f5969c7/polars-1.30.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a52af3862082b868c1febeae650af8ae8a2105d2cb28f0449179a7b44f54ccf", size = 36267243, upload-time = "2025-05-21T13:32:31.796Z" }, - { url = "https://files.pythonhosted.org/packages/cb/e7/b88b973021be07b13d91b9301cc14392c994225ef5107a32a8ffd3fd6424/polars-1.30.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = 
"sha256:ffb3ef133454275d4254442257c5f71dd6e393ce365c97997dadeb6fa9d6d4b5", size = 33416871, upload-time = "2025-05-21T13:32:35.077Z" }, - { url = "https://files.pythonhosted.org/packages/dd/7c/d46d4381adeac537b8520b653dc30cb8b7edbf59883d71fbb989e9005de1/polars-1.30.0-cp39-abi3-win_amd64.whl", hash = "sha256:c26b633a9bd530c5fc09d317fca3bb3e16c772bd7df7549a9d8ec1934773cc5d", size = 36363630, upload-time = "2025-05-21T13:32:38.286Z" }, - { url = "https://files.pythonhosted.org/packages/fb/b5/5056d0c12aadb57390d0627492bef8b1abf3549474abb9ae0fd4e2bfa885/polars-1.30.0-cp39-abi3-win_arm64.whl", hash = "sha256:476f1bde65bc7b4d9f80af370645c2981b5798d67c151055e58534e89e96f2a8", size = 32643590, upload-time = "2025-05-21T13:32:42.107Z" }, + { url = "https://files.pythonhosted.org/packages/3d/6e/bdd0937653c1e7a564a09ae3bc7757ce83fedbf19da600c8b35d62c0182a/polars-1.31.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ccc68cd6877deecd46b13cbd2663ca89ab2a2cb1fe49d5cfc66a9cef166566d9", size = 34511354, upload-time = "2025-06-18T11:59:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/77/fe/81aaca3540c1a5530b4bc4fd7f1b6f77100243d7bb9b7ad3478b770d8b3e/polars-1.31.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:a94c5550df397ad3c2d6adc212e59fd93d9b044ec974dd3653e121e6487a7d21", size = 31377712, upload-time = "2025-06-18T11:59:45.104Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d9/5e2753784ea30d84b3e769a56f5e50ac5a89c129e87baa16ac0773eb4ef7/polars-1.31.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada7940ed92bea65d5500ae7ac1f599798149df8faa5a6db150327c9ddbee4f1", size = 35050729, upload-time = "2025-06-18T11:59:48.538Z" }, + { url = "https://files.pythonhosted.org/packages/20/e8/a6bdfe7b687c1fe84bceb1f854c43415eaf0d2fdf3c679a9dc9c4776e462/polars-1.31.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:b324e6e3e8c6cc6593f9d72fe625f06af65e8d9d47c8686583585533a5e731e1", size = 32260836, upload-time = "2025-06-18T11:59:52.543Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f6/9d9ad9dc4480d66502497e90ce29efc063373e1598f4bd9b6a38af3e08e7/polars-1.31.0-cp39-abi3-win_amd64.whl", hash = "sha256:3fd874d3432fc932863e8cceff2cff8a12a51976b053f2eb6326a0672134a632", size = 35156211, upload-time = "2025-06-18T11:59:55.805Z" }, + { url = "https://files.pythonhosted.org/packages/40/4b/0673a68ac4d6527fac951970e929c3b4440c654f994f0c957bd5556deb38/polars-1.31.0-cp39-abi3-win_arm64.whl", hash = "sha256:62ef23bb9d10dca4c2b945979f9a50812ac4ace4ed9e158a6b5d32a7322e6f75", size = 31469078, upload-time = "2025-06-18T11:59:59.242Z" }, ] [[package]] From 09eb9473cd35af5169b2c267f0ab8421353ecf29 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 24 Jun 2025 22:26:10 +0000 Subject: [PATCH 04/57] refactor: implement ContentHashableBase --- src/orcapod/core/base.py | 21 ++++----- src/orcapod/core/pod.py | 5 ++- src/orcapod/hashing/__init__.py | 2 + src/orcapod/hashing/content_hashable.py | 57 +++++++++++++++++++++++++ src/orcapod/hashing/core.py | 16 ++++--- src/orcapod/hashing/types.py | 2 + 6 files changed, 82 insertions(+), 21 deletions(-) create mode 100644 src/orcapod/hashing/content_hashable.py diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 664352d..0dd45ad 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -5,7 +5,10 @@ from typing import Any -from orcapod.hashing import HashableMixin +from orcapod.hashing import HashableMixin, ObjectHasher +from orcapod.hashing import get_default_object_hasher + +from orcapod.hashing import ContentHashableBase from orcapod.types import Packet, Tag, TypeSpec from orcapod.utils.stream_utils import get_typespec @@ -15,7 +18,7 @@ logger = logging.getLogger(__name__) -class Kernel(ABC, HashableMixin): +class Kernel(ABC, ContentHashableBase): """ Kernel defines the fundamental unit of computation that can be performed on zero, one or more streams of data. It is the base class for all computations and transformations that can be performed on a collection of streams @@ -27,7 +30,7 @@ class Kernel(ABC, HashableMixin): for computational graph tracking. """ - def __init__(self, label: str | None = None, skip_tracking: bool = False, **kwargs) -> None: + def __init__(self, label: str | None = None, skip_tracking: bool = False,**kwargs) -> None: super().__init__(**kwargs) self._label = label self._skip_tracking = skip_tracking @@ -227,7 +230,7 @@ def record(self, invocation: "Invocation") -> None: ... # This is NOT an abstract class, but rather a concrete class that # represents an invocation of a kernel on a collection of streams. -class Invocation(HashableMixin): +class Invocation(ContentHashableBase): """ This class represents an invocation of a kernel on a collection of streams. It contains the kernel and the streams that were used in the invocation. @@ -244,20 +247,12 @@ def __init__( self.kernel = kernel self.streams = streams - def __hash__(self) -> int: - return super().__hash__() - def __repr__(self) -> str: return f"Invocation(kernel={self.kernel}, streams={self.streams})" def __str__(self) -> str: return f"Invocation[ID:{self.__hash__()}]({self.kernel}, {self.streams})" - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Invocation): - return False - return hash(self) == hash(other) - def __lt__(self, other: Any) -> bool: if not isinstance(other, Invocation): return NotImplemented @@ -294,7 +289,7 @@ def identity_structure(self) -> int: return self.kernel.identity_structure(*self.streams) -class Stream(ABC, HashableMixin): +class Stream(ABC, ContentHashableBase): """ A stream is a collection of tagged-packets that are generated by an operation. The stream is iterable and can be used to access the packets in the stream. 
diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index 582fa85..77d1610 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -228,8 +228,9 @@ def __repr__(self) -> str: return f"FunctionPod:{self.function!r}" def __str__(self) -> str: - func_sig = get_function_signature(self.function) - return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" + include_module = self.function.__module__ != "__main__" + func_sig = get_function_signature(self.function, name_override=self.function_name, include_module=include_module) + return f"FunctionPod:{func_sig}" def call(self, tag, packet) -> tuple[Tag, Packet | None]: if not self.is_active(): diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index 98a15da..f95b0f7 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -24,6 +24,7 @@ FunctionInfoExtractor, CompositeFileHasher, ) +from .content_hashable import ContentHashableBase __all__ = [ "FileHasher", @@ -46,4 +47,5 @@ "get_default_composite_file_hasher", "get_default_object_hasher", "get_default_arrow_hasher", + "ContentHashableBase", ] diff --git a/src/orcapod/hashing/content_hashable.py b/src/orcapod/hashing/content_hashable.py new file mode 100644 index 0000000..f3fc929 --- /dev/null +++ b/src/orcapod/hashing/content_hashable.py @@ -0,0 +1,57 @@ + +from .types import ObjectHasher +from .defaults import get_default_object_hasher +from typing import Any + +class ContentHashableBase: + def __init__(self, object_hasher: ObjectHasher | None = None) -> None: + """ + Initialize the ContentHashable with an optional ObjectHasher. + + Args: + object_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. + """ + self.object_hasher = object_hasher or get_default_object_hasher() + + + def identity_structure(self) -> Any: + """ + Return a structure that represents the identity of this object. + + Override this method in your subclass to provide a stable representation + of your object's content. The structure should contain all fields that + determine the object's identity. + + Returns: + Any: A structure representing this object's content, or None to use default hash + """ + return None + + + def __hash__(self) -> int: + """ + Hash implementation that uses the identity structure if provided, + otherwise falls back to the superclass's hash method. + + Returns: + int: A hash value based on either content or identity + """ + # Get the identity structure + structure = self.identity_structure() + + return self.object_hasher.hash_to_int(structure) + + def __eq__(self, other: object) -> bool: + """ + Equality check that compares the identity structures of two objects. + + Args: + other (object): The object to compare against. + + Returns: + bool: True if both objects have the same identity structure, False otherwise. + """ + if not isinstance(other, ContentHashableBase): + return NotImplemented + + return self.identity_structure() == other.identity_structure() \ No newline at end of file diff --git a/src/orcapod/hashing/core.py b/src/orcapod/hashing/core.py index 66b4e4d..6e40bde 100644 --- a/src/orcapod/hashing/core.py +++ b/src/orcapod/hashing/core.py @@ -832,6 +832,7 @@ def get_function_signature( name_override: str | None = None, include_defaults: bool = True, include_module: bool = True, + output_names: Collection[str] | None = None ) -> str: """ Get a stable string representation of a function's signature. 
@@ -847,14 +848,14 @@ def get_function_signature( sig = inspect.signature(func) # Build the signature string - parts = [] + parts = {} # Add module if requested if include_module and hasattr(func, "__module__"): - parts.append(f"module:{func.__module__}") + parts["module"] = func.__module__ # Add function name - parts.append(f"name:{name_override or func.__name__}") + parts["name"] = name_override or func.__name__ # Add parameters param_strs = [] @@ -864,13 +865,16 @@ def get_function_signature( param_str = param_str.split("=")[0].strip() param_strs.append(param_str) - parts.append(f"params:({', '.join(param_strs)})") + parts["params"] = f"({', '.join(param_strs)})" # Add return annotation if present if sig.return_annotation is not inspect.Signature.empty: - parts.append(f"returns:{sig.return_annotation}") + parts["returns"] = sig.return_annotation - return " ".join(parts) + fn_string = f"{parts["module"] + "." if "module" in parts else ""}{parts["name"]}{parts["params"]}" + if "returns" in parts: + fn_string = fn_string + f"-> {str(parts["returns"])}" + return fn_string def _is_in_string(line, pos): diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index 36155bb..abae409 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -140,3 +140,5 @@ def extract_function_info( input_typespec: TypeSpec | None = None, output_typespec: TypeSpec | None = None, ) -> dict[str, Any]: ... + + From bd3c7a871641fc128ce2d3e37bede5b14f62b301 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 25 Jun 2025 03:10:29 +0000 Subject: [PATCH 05/57] refactor: significantly clean up label logic --- src/orcapod/core/base.py | 64 +++++++++---------------- src/orcapod/core/operators.py | 14 +++++- src/orcapod/core/streams.py | 35 ++++++++++++-- src/orcapod/core/tracker.py | 47 ++++++++++++++---- src/orcapod/hashing/content_hashable.py | 31 +++++++++++- src/orcapod/hashing/core.py | 6 +++ src/orcapod/pipeline/pipeline.py | 52 +++++++++++++++----- src/orcapod/pipeline/wrappers.py | 14 +++--- 8 files changed, 184 insertions(+), 79 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 0dd45ad..9aa8b4d 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -30,24 +30,12 @@ class Kernel(ABC, ContentHashableBase): for computational graph tracking. """ - def __init__(self, label: str | None = None, skip_tracking: bool = False,**kwargs) -> None: + def __init__(self, label: str | None = None, skip_tracking: bool = False, **kwargs) -> None: super().__init__(**kwargs) self._label = label self._skip_tracking = skip_tracking - @property - def label(self) -> str: - """ - Returns a human-readable label for this kernel. - Default implementation returns the provided label or class name if no label was provided. 
- """ - if self._label: - return self._label - return self.__class__.__name__ - - @label.setter - def label(self, label: str) -> None: - self._label = label + def pre_forward_hook( self, *streams: "SyncStream", **kwargs @@ -68,7 +56,9 @@ def post_forward_hook(self, output_stream: "SyncStream", **kwargs) -> "SyncStrea return output_stream - def __call__(self, *streams: "SyncStream", **kwargs) -> "SyncStream": + def __call__(self, *streams: "SyncStream", label:str|None = None, **kwargs) -> "SyncStream": + if label is not None: + self.label = label # Special handling of Source: trigger call on source if passed as stream normalized_streams = [ stream() if isinstance(stream, Source) else stream for stream in streams @@ -243,10 +233,19 @@ def __init__( kernel: Kernel, # TODO: technically this should be Stream to stay consistent with Stream interface. Update to Stream when AsyncStream is implemented streams: Collection["SyncStream"], + **kwargs, ) -> None: + super().__init__(**kwargs) self.kernel = kernel self.streams = streams + def computed_label(self) -> str | None: + """ + Returns the computed label for this invocation. + This is used to provide a default label if no label is set. + """ + return self.kernel.label + def __repr__(self) -> str: return f"Invocation(kernel={self.kernel}, streams={self.streams})" @@ -298,35 +297,16 @@ class Stream(ABC, ContentHashableBase): This may be None if the stream is not generated by a kernel (i.e. directly instantiated by a user). """ - def __init__(self, label: str | None = None, **kwargs) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self._invocation: Invocation | None = None - self._label = label - - @property - def label(self) -> str: - """ - Returns a human-readable label for this stream. - If no label is provided and the stream is generated by an operation, - the label of the operation is used. - Otherwise, the class name is used as the label. - """ - if self._label is None: - if self.invocation is not None: - # use the invocation operation label - return self.invocation.kernel.label - else: - return self.__class__.__name__ - return self._label - @label.setter - def label(self, label: str) -> None: - """ - Sets a human-readable label for this stream. - """ - if not isinstance(label, str): - raise TypeError("label must be a string") - self._label = label + def computed_label(self) -> str | None: + if self.invocation is not None: + # use the invocation operation label + return self.invocation.kernel.label + return None + @property def invocation(self) -> Invocation | None: @@ -347,7 +327,7 @@ def flow(self) -> Collection[tuple[Tag, Packet]]: Flow everything through the stream, returning the entire collection of (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. 
""" - return list(self) + return [e for e in self] # --------------------- Recursive methods --------------------------- # These methods form a step in the multi-class recursive invocation that follows the pattern of diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index 84a31f3..654d9d2 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -438,7 +438,12 @@ def keys( stream = streams[0] tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) if tag_keys is None or packet_keys is None: - return super().keys(trigger_run=trigger_run) + super_tag_keys, super_packet_keys = super().keys(trigger_run=trigger_run) + tag_keys = tag_keys or super_tag_keys + packet_keys = packet_keys or super_packet_keys + + if packet_keys is None: + return tag_keys, packet_keys if self.drop_unmapped: # If drop_unmapped is True, we only keep the keys that are in the mapping @@ -464,7 +469,12 @@ def types( stream = streams[0] tag_types, packet_types = stream.types(trigger_run=trigger_run) if tag_types is None or packet_types is None: - return super().types(trigger_run=trigger_run) + super_tag_types, super_packet_types = super().types(trigger_run=trigger_run) + tag_types = tag_types or super_tag_types + packet_types = packet_types or super_packet_types + + if packet_types is None: + return tag_types, packet_types if self.drop_unmapped: # If drop_unmapped is True, we only keep the keys that are in the mapping diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index 77cdbe3..f5c5e60 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -1,7 +1,8 @@ from collections.abc import Callable, Collection, Iterator from orcapod.core.base import SyncStream -from orcapod.types import Packet, Tag +from orcapod.types import Packet, Tag, TypeSpec +from copy import copy class SyncStreamFromLists(SyncStream): @@ -12,12 +13,21 @@ def __init__( paired: Collection[tuple[Tag, Packet]] | None = None, tag_keys: list[str] | None = None, packet_keys: list[str] | None = None, + tag_typespec: TypeSpec | None = None, + packet_typespec: TypeSpec | None = None, strict: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) + self.tag_typespec = tag_typespec + self.packet_typespec = packet_typespec + if tag_keys is None and tag_typespec is not None: + tag_keys = list(tag_typespec.keys()) + if packet_keys is None and packet_typespec is not None: + packet_keys = list(packet_typespec.keys()) self.tag_keys = tag_keys self.packet_keys = packet_keys + if tags is not None and packets is not None: if strict and len(tags) != len(packets): raise ValueError( @@ -34,14 +44,31 @@ def __init__( def keys( self, *, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: - if self.tag_keys is None or self.packet_keys is None: - return super().keys(trigger_run=trigger_run) + tag_keys, packet_keys = copy(self.tag_keys), copy(self.packet_keys) + if tag_keys is None or packet_keys is None: + super_tag_keys, super_packet_keys = super().keys(trigger_run=trigger_run) + tag_keys = tag_keys or super_tag_keys + packet_keys = packet_keys or super_packet_keys + # If the keys are already set, return them - return self.tag_keys.copy(), self.packet_keys.copy() + return tag_keys, packet_keys + + def types( + self, *, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + tag_typespec, packet_typespec = copy(self.tag_typespec), copy(self.packet_typespec) + if tag_typespec is None or packet_typespec is None: + 
super_tag_typespec, super_packet_typespec = super().types(trigger_run=trigger_run) + tag_typespec = tag_typespec or super_tag_typespec + packet_typespec = packet_typespec or super_packet_typespec + + # If the types are already set, return them + return tag_typespec, packet_typespec def __iter__(self) -> Iterator[tuple[Tag, Packet]]: yield from self.paired + class SyncStreamFromGenerator(SyncStream): """ diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py index 2532582..337c027 100644 --- a/src/orcapod/core/tracker.py +++ b/src/orcapod/core/tracker.py @@ -1,8 +1,37 @@ -from orcapod.core.base import Invocation, Kernel, Tracker, SyncStream, TypeSpec -from collections.abc import Collection +from orcapod.core.base import Invocation, Kernel, Tracker, SyncStream, Source +from orcapod.types import Tag, Packet, TypeSpec +from collections.abc import Collection, Iterator from typing import Any -class StubKernel(Kernel): +class StreamWrapper(SyncStream): + """ + A wrapper for a SyncStream that allows it to be used as a Source. + This is useful for cases where you want to treat a stream as a source + without modifying the original stream. + """ + + def __init__(self, stream: SyncStream, **kwargs): + super().__init__(**kwargs) + self.stream = stream + + def keys(self, *streams: SyncStream, **kwargs) -> tuple[Collection[str]|None, Collection[str]|None]: + return self.stream.keys(*streams, **kwargs) + + def types(self, *streams: SyncStream, **kwargs) -> tuple[TypeSpec|None, TypeSpec|None]: + return self.stream.types(*streams, **kwargs) + + def computed_label(self) -> str | None: + return self.stream.label + + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + """ + Iterate over the stream, yielding tuples of (tags, packets). + """ + yield from self.stream + + + +class StreamSource(Source): def __init__(self, stream: SyncStream, **kwargs): super().__init__(skip_tracking=True, **kwargs) self.stream = stream @@ -10,15 +39,15 @@ def __init__(self, stream: SyncStream, **kwargs): def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 0: raise ValueError( - "StubKernel does not support forwarding streams. " + "StreamSource does not support forwarding streams. " "It generates its own stream from the file system." ) - return self.stream + return StreamWrapper(self.stream) def identity_structure(self, *streams) -> Any: if len(streams) != 0: raise ValueError( - "StubKernel does not support forwarding streams. " + "StreamSource does not support forwarding streams. " "It generates its own stream from the file system." 
) @@ -29,11 +58,11 @@ def types(self, *streams: SyncStream, **kwargs) -> tuple[TypeSpec|None, TypeSpec def keys(self, *streams: SyncStream, **kwargs) -> tuple[Collection[str]|None, Collection[str]|None]: return self.stream.keys() - + def computed_label(self) -> str | None: + return self.stream.label - class GraphTracker(Tracker): """ @@ -89,7 +118,7 @@ def generate_graph(self): upstream_invocation = upstream.invocation if upstream_invocation is None: # If upstream is None, create a stub kernel - upstream_invocation = Invocation(StubKernel(upstream, label="StubInput"), []) + upstream_invocation = Invocation(StreamSource(upstream), []) if upstream_invocation not in G: G.add_node(upstream_invocation) G.add_edge(upstream_invocation, invocation, stream=upstream) diff --git a/src/orcapod/hashing/content_hashable.py b/src/orcapod/hashing/content_hashable.py index f3fc929..33fd1bb 100644 --- a/src/orcapod/hashing/content_hashable.py +++ b/src/orcapod/hashing/content_hashable.py @@ -3,8 +3,9 @@ from .defaults import get_default_object_hasher from typing import Any + class ContentHashableBase: - def __init__(self, object_hasher: ObjectHasher | None = None) -> None: + def __init__(self, object_hasher: ObjectHasher | None = None, label: str | None = None) -> None: """ Initialize the ContentHashable with an optional ObjectHasher. @@ -12,6 +13,31 @@ def __init__(self, object_hasher: ObjectHasher | None = None) -> None: object_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. """ self.object_hasher = object_hasher or get_default_object_hasher() + self._label = label + + @property + def label(self) -> str : + """ + Get the label of this object. + + Returns: + str | None: The label of the object, or None if not set. + """ + return self._label or self.computed_label() or self.__class__.__name__ + + @label.setter + def label(self, label: str | None) -> None: + """ + Set the label of this object. + + Args: + label (str | None): The label to set for this object. 
+ """ + self._label = label + + def computed_label(self) -> str|None: + return None + def identity_structure(self) -> Any: @@ -38,6 +64,9 @@ def __hash__(self) -> int: """ # Get the identity structure structure = self.identity_structure() + if structure is None: + # If no identity structure is provided, use the default hash + return super().__hash__() return self.object_hasher.hash_to_int(structure) diff --git a/src/orcapod/hashing/core.py b/src/orcapod/hashing/core.py index 6e40bde..d37ade3 100644 --- a/src/orcapod/hashing/core.py +++ b/src/orcapod/hashing/core.py @@ -29,6 +29,7 @@ ) from uuid import UUID + import xxhash from orcapod.types import Packet, PathSet @@ -435,6 +436,11 @@ def process_structure( if isinstance(obj, HashableMixin): logger.debug(f"Processing HashableMixin instance of type {type(obj).__name__}") return obj.content_hash() + + from .content_hashable import ContentHashableBase + if isinstance(obj, ContentHashableBase): + logger.debug(f"Processing ContentHashableBase instance of type {type(obj).__name__}") + return process_structure(obj.identity_structure(), visited, function_info_extractor) # Handle basic types if isinstance(obj, (str, int, float, bool)): diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 5df050f..7edd03e 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -1,15 +1,12 @@ from collections import defaultdict -from collections.abc import Collection, Iterator -import json +from collections.abc import Collection import logging import pickle import sys import time -from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Protocol, runtime_checkable +from typing import Any -import pandas as pd from orcapod.core import Invocation, Kernel, SyncStream from orcapod.core.pod import FunctionPod @@ -17,13 +14,7 @@ from orcapod.hashing import hash_to_hex from orcapod.core.tracker import GraphTracker -from orcapod.hashing import ObjectHasher, ArrowHasher -from orcapod.types import TypeSpec, Tag, Packet -from orcapod.core.streams import SyncStreamFromGenerator from orcapod.store import ArrowDataStore -from orcapod.types.registry import PacketConverter, TypeRegistry -from orcapod.types import default_registry -from orcapod.utils.stream_utils import merge_typespecs, get_typespec logger = logging.getLogger(__name__) @@ -40,12 +31,15 @@ class Pipeline(GraphTracker): Replaces the old Tracker with better persistence and view capabilities. """ - def __init__(self, name: str, results_store: ArrowDataStore, pipeline_store: ArrowDataStore) -> None: + def __init__(self, name: str, results_store: ArrowDataStore, pipeline_store: ArrowDataStore, auto_compile:bool=True) -> None: super().__init__() self.name = name or f"pipeline_{id(self)}" self.results_store = results_store self.pipeline_store = pipeline_store self.labels_to_nodes = {} + self.auto_compile = auto_compile + self._dirty = False + self._ordered_nodes = [] # Track order of invocations # Core Pipeline Operations def save(self, path: Path | str) -> None: @@ -77,6 +71,14 @@ def save(self, path: Path | str) -> None: temp_path.unlink() raise + def record(self, invocation: Invocation) -> None: + """ + Record an invocation in the pipeline. + This method is called automatically by the Kernel when an operation is invoked. 
+ """ + super().record(invocation) + self._dirty = True + def wrap_invocation( self, kernel: Kernel, input_nodes: Collection[Node] ) -> Node: @@ -93,6 +95,7 @@ def compile(self): proposed_labels = defaultdict(list) node_lut = {} edge_lut : dict[SyncStream, Node]= {} + ordered_nodes = [] for invocation in nx.topological_sort(G): # map streams to the new streams based on Nodes input_nodes = [edge_lut[stream] for stream in invocation.streams] @@ -100,11 +103,14 @@ def compile(self): # register the new node against the original invocation node_lut[invocation] = new_node + ordered_nodes.append(new_node) # register the new node in the proposed labels -- if duplicates occur, will resolve later proposed_labels[new_node.label].append(new_node) for edge in G.out_edges(invocation): edge_lut[G.edges[edge]["stream"]] = new_node + + self._ordered_nodes = ordered_nodes # resolve duplicates in proposed_labels labels_to_nodes = {} @@ -120,8 +126,15 @@ def compile(self): labels_to_nodes[label] = nodes[0] self.labels_to_nodes = labels_to_nodes + self._dirty = False return node_lut, edge_lut, proposed_labels, labels_to_nodes + def __exit__(self, exc_type, exc_val, ext_tb): + super().__exit__(exc_type, exc_val, ext_tb) + if self.auto_compile: + self.compile() + + def __getattr__(self, item: str) -> Any: """Allow direct access to pipeline attributes""" if item in self.labels_to_nodes: @@ -131,8 +144,21 @@ def __getattr__(self, item: str) -> Any: def __dir__(self): # Include both regular attributes and dynamic ones return list(super().__dir__()) + list(self.labels_to_nodes.keys()) - + def run(self, full_sync:bool=False) -> None: + """ + Run the pipeline, compiling it if necessary. + This method is a no-op if auto_compile is False. + """ + if self.auto_compile and self._dirty: + self.compile() + + # Run in topological order + for node in self._ordered_nodes: + if full_sync: + node.reset_cache() + node.flow() + @classmethod def load(cls, path: Path | str) -> "Pipeline": """Load complete pipeline state""" diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index 55207c2..b6d5f18 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -126,13 +126,11 @@ def __repr__(self): def __str__(self): return f"{self.__class__.__name__}<{self.kernel}>" - @property - def label(self) -> str: - return self._label or self.kernel.label - - @label.setter - def label(self, label: str) -> None: - self._label = label + def computed_label(self) -> str | None: + """ + Return the label of the wrapped kernel. + """ + return self.kernel.label def resolve_input_streams(self, *input_streams) -> Collection[SyncStream]: if input_streams: @@ -207,7 +205,7 @@ def __init__( if _registry is None: _registry = default_registry self.registry = _registry - self.source_info = self.label, self.object_hasher.hash_to_hex(self.kernel) + self.source_info = self.label, str(hash(self.kernel)) self._cache_computed = False From 90b9dada7fac5b76d842bf3a9a449db957341425 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 25 Jun 2025 03:17:42 +0000 Subject: [PATCH 06/57] optim: avoid len call by using list comprehension --- src/orcapod/core/operators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index 654d9d2..b5c4512 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -261,8 +261,9 @@ def forward(self, *streams: SyncStream) -> SyncStream: left_stream, right_stream = streams def generator() -> Iterator[tuple[Tag, Packet]]: - left_stream_buffered = list(left_stream) - right_stream_buffered = list(right_stream) + # using list comprehension rather than list() to avoid call to __len__ which is expensive + left_stream_buffered = [e for e in left_stream] + right_stream_buffered = [e for e in right_stream] for left_tag, left_packet in left_stream_buffered: for right_tag, right_packet in right_stream_buffered: if (joined_tag := join_tags(left_tag, right_tag)) is not None: From 1e61259f890c34aa3a15ee1704dd618d7a7ecd00 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 25 Jun 2025 07:17:37 +0000 Subject: [PATCH 07/57] refactor: place Operator back in base --- src/orcapod/core/operators.py | 85 +++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index b5c4512..53ecacc 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -5,24 +5,21 @@ from orcapod.types import Packet, Tag, TypeSpec from orcapod.hashing import function_content_hash, hash_function -from orcapod.core.base import Kernel, SyncStream +from orcapod.core.base import Kernel, SyncStream, Operator from orcapod.core.streams import SyncStreamFromGenerator from orcapod.utils.stream_utils import ( batch_packet, batch_tags, check_packet_compatibility, + intersection_typespecs, join_tags, - fill_missing, - merge_typespecs, + semijoin_tags, + union_typespecs, + intersection_typespecs, + fill_missing ) -class Operator(Kernel): - """ - A Mapper is an operation that does NOT generate new file content. - It is used to control the flow of data in the pipeline without modifying or creating data content. 
- """ - class Repeat(Operator): """ @@ -245,8 +242,8 @@ def types( right_tag_types, right_packet_types = right_stream.types(trigger_run=False) # TODO: do error handling when merge fails - joined_tag_types = merge_typespecs(left_tag_types, right_tag_types) - joined_packet_types = merge_typespecs(left_packet_types, right_packet_types) + joined_tag_types = union_typespecs(left_tag_types, right_tag_types) + joined_packet_types = union_typespecs(left_packet_types, right_packet_types) return joined_tag_types, joined_packet_types @@ -377,8 +374,8 @@ def types( ): return super().types(*streams, trigger_run=trigger_run) - joined_tag_types = merge_typespecs(left_tag_types, right_tag_types) - joined_packet_types = merge_typespecs(left_packet_types, right_packet_types) + joined_tag_types = union_typespecs(left_tag_types, right_tag_types) + joined_packet_types = union_typespecs(left_packet_types, right_packet_types) return joined_tag_types, joined_packet_types @@ -599,6 +596,68 @@ def keys( return mapped_tag_keys, packet_keys +class SemiJoin(Operator): + """ + Perform semi-join on the left stream tags with the tags of the right stream + """ + def identity_structure(self, *streams): + # Restrict DOES depend on the order of the streams -- maintain as a tuple + return (self.__class__.__name__,) + streams + + def keys( + self, *streams: SyncStream, trigger_run=False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + For semijoin, output keys and types are identical to left stream + """ + if len(streams) != 2: + raise ValueError("Join operation requires exactly two streams") + + return streams[0].keys(trigger_run=trigger_run) + + def types( + self, *streams: SyncStream, trigger_run=False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + For semijoin, output keys and types are identical to left stream + """ + if len(streams) != 2: + raise ValueError("Join operation requires exactly two streams") + + return streams[0].types(trigger_run=trigger_run) + + def forward(self, *streams: SyncStream) -> SyncStream: + """ + Joins two streams together based on their tags. + The resulting stream will contain all the tags from both streams. + """ + if len(streams) != 2: + raise ValueError("Join operation requires exactly two streams") + + left_stream, right_stream = streams + left_tag_typespec, left_packet_typespec = left_stream.types() + right_tag_typespec, right_packet_typespec = right_stream.types() + + common_tag_typespec = intersection_typespecs(left_tag_typespec, right_tag_typespec) + common_tag_keys = None + if common_tag_typespec is not None: + common_tag_keys = list(common_tag_typespec.keys()) + + def generator() -> Iterator[tuple[Tag, Packet]]: + # using list comprehension rather than list() to avoid call to __len__ which is expensive + left_stream_buffered = [e for e in left_stream] + right_stream_buffered = [e for e in right_stream] + for left_tag, left_packet in left_stream_buffered: + for right_tag, _ in right_stream_buffered: + if semijoin_tags(left_tag, right_tag, common_tag_keys) is not None: + yield left_tag, left_packet + # move onto next entry + break + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + return "SemiJoin()" class Filter(Operator): """ From df581342eb82a2f7a44ae716a0da5882c50d265c Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 25 Jun 2025 07:18:29 +0000 Subject: [PATCH 08/57] refactor: place operator in base and add additional operator methods to sync stream --- src/orcapod/core/base.py | 77 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 7 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 9aa8b4d..0a99a8a 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -8,7 +8,7 @@ from orcapod.hashing import HashableMixin, ObjectHasher from orcapod.hashing import get_default_object_hasher -from orcapod.hashing import ContentHashableBase +from orcapod.hashing import ContentIdentifiableBase from orcapod.types import Packet, Tag, TypeSpec from orcapod.utils.stream_utils import get_typespec @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class Kernel(ABC, ContentHashableBase): +class Kernel(ABC, ContentIdentifiableBase): """ Kernel defines the fundamental unit of computation that can be performed on zero, one or more streams of data. It is the base class for all computations and transformations that can be performed on a collection of streams @@ -36,7 +36,6 @@ def __init__(self, label: str | None = None, skip_tracking: bool = False, **kwar self._skip_tracking = skip_tracking - def pre_forward_hook( self, *streams: "SyncStream", **kwargs ) -> tuple["SyncStream", ...]: @@ -220,7 +219,7 @@ def record(self, invocation: "Invocation") -> None: ... # This is NOT an abstract class, but rather a concrete class that # represents an invocation of a kernel on a collection of streams. -class Invocation(ContentHashableBase): +class Invocation(ContentIdentifiableBase): """ This class represents an invocation of a kernel on a collection of streams. It contains the kernel and the streams that were used in the invocation. @@ -288,7 +287,7 @@ def identity_structure(self) -> int: return self.kernel.identity_structure(*self.streams) -class Stream(ABC, ContentHashableBase): +class Stream(ABC, ContentIdentifiableBase): """ A stream is a collection of tagged-packets that are generated by an operation. The stream is iterable and can be used to access the packets in the stream. @@ -434,6 +433,64 @@ def __len__(self) -> int: """ return sum(1 for _ in self) + def join(self, other: "SyncStream", label:str|None=None) -> "SyncStream": + """ + Returns a new stream that is the result of joining with the other stream. + The join is performed on the tags of the packets in the streams. + """ + from .operators import Join + + if not isinstance(other, SyncStream): + raise TypeError("other must be a SyncStream") + return Join(label=label)(self, other) + + def semijoin(self, other: "SyncStream", label: str | None = None) -> "SyncStream": + """ + Returns a new stream that is the result of semijoining with the other stream. + The semijoin is performed on the tags of the packets in the streams. + """ + from .operators import SemiJoin + + if not isinstance(other, SyncStream): + raise TypeError("other must be a SyncStream") + return SemiJoin(label=label)(self, other) + + def map(self, packet_map: dict | None = None, tag_map: dict | None = None, drop_unmapped:bool=True) -> "SyncStream": + """ + Returns a new stream that is the result of mapping the packets and tags in the stream. + The mapping is applied to each packet in the stream and the resulting packets + are returned in a new stream. + If packet_map is None, no mapping is applied to the packets. + If tag_map is None, no mapping is applied to the tags. 
+ """ + from .operators import MapTags, MapPackets + output = self + if packet_map is not None: + output = MapPackets(packet_map, drop_unmapped=drop_unmapped)(output) + if tag_map is not None: + output = MapTags(tag_map, drop_unmapped=drop_unmapped)(output) + + return output + + def apply(self, transformer: 'dict | Operator') -> "SyncStream": + """ + Returns a new stream that is the result of applying the mapping to the stream. + The mapping is applied to each packet in the stream and the resulting packets + are returned in a new stream. + """ + from .operators import MapPackets + + if isinstance(transformer, dict): + return MapPackets(transformer)(self) + elif isinstance(transformer, Operator): + # If the transformer is an Operator, we can apply it directly + return transformer(self) + + # Otherwise, do not know how to handle the transformer + raise TypeError( + "transformer must be a dictionary or an operator" + ) + def __rshift__( self, transformer: dict | Callable[["SyncStream"], "SyncStream"] ) -> "SyncStream": @@ -442,7 +499,6 @@ def __rshift__( The mapping is applied to each packet in the stream and the resulting packets are returned in a new stream. """ - # TODO: remove just in time import from .operators import MapPackets if isinstance(transformer, dict): @@ -459,13 +515,13 @@ def __mul__(self, other: "SyncStream") -> "SyncStream": """ Returns a new stream that is the result joining with the other stream """ - # TODO: remove just in time import from .operators import Join if not isinstance(other, SyncStream): raise TypeError("other must be a SyncStream") return Join()(self, other) + def claims_unique_tags(self, *, trigger_run=False) -> bool | None: """ For synchronous streams, if the stream is generated by an operation, the invocation @@ -490,6 +546,13 @@ def claims_unique_tags(self, *, trigger_run=False) -> bool | None: return True +class Operator(Kernel): + """ + A Mapper is an operation that does NOT generate new file content. + It is used to control the flow of data in the pipeline without modifying or creating data content. + """ + + class Source(Kernel, SyncStream): """ A base class for all sources in the system. A source can be seen as a special From 6e4d4bd876ceb38dc931644e53e4d98e5090e888 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 25 Jun 2025 07:19:12 +0000 Subject: [PATCH 09/57] wip: change to content identifable base --- src/orcapod/core/streams.py | 1 + src/orcapod/hashing/__init__.py | 4 +- src/orcapod/hashing/content_hashable.py | 17 ++-- src/orcapod/hashing/core.py | 4 +- src/orcapod/pipeline/wrappers.py | 112 +++++++++++++++++------- src/orcapod/utils/stream_utils.py | 35 +++++++- 6 files changed, 130 insertions(+), 43 deletions(-) diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index f5c5e60..a1e9620 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -5,6 +5,7 @@ from copy import copy + class SyncStreamFromLists(SyncStream): def __init__( self, diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index f95b0f7..d3d83e9 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -24,7 +24,7 @@ FunctionInfoExtractor, CompositeFileHasher, ) -from .content_hashable import ContentHashableBase +from .content_hashable import ContentIdentifiableBase __all__ = [ "FileHasher", @@ -47,5 +47,5 @@ "get_default_composite_file_hasher", "get_default_object_hasher", "get_default_arrow_hasher", - "ContentHashableBase", + "ContentIdentifiableBase", ] diff --git a/src/orcapod/hashing/content_hashable.py b/src/orcapod/hashing/content_hashable.py index 33fd1bb..61eb0e5 100644 --- a/src/orcapod/hashing/content_hashable.py +++ b/src/orcapod/hashing/content_hashable.py @@ -4,15 +4,15 @@ from typing import Any -class ContentHashableBase: - def __init__(self, object_hasher: ObjectHasher | None = None, label: str | None = None) -> None: +class ContentIdentifiableBase: + def __init__(self, identity_structure_hasher: ObjectHasher | None = None, label: str | None = None) -> None: """ Initialize the ContentHashable with an optional ObjectHasher. Args: - object_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. + identity_structure_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. """ - self.object_hasher = object_hasher or get_default_object_hasher() + self.identity_structure_hasher = identity_structure_hasher or get_default_object_hasher() self._label = label @property @@ -36,10 +36,13 @@ def label(self, label: str | None) -> None: self._label = label def computed_label(self) -> str|None: + """ + Compute a label for this object based on its content. If label is not explicitly set for this object + and computed_label returns a valid value, it will be used as label of this object. + """ return None - def identity_structure(self) -> Any: """ Return a structure that represents the identity of this object. @@ -68,7 +71,7 @@ def __hash__(self) -> int: # If no identity structure is provided, use the default hash return super().__hash__() - return self.object_hasher.hash_to_int(structure) + return self.identity_structure_hasher.hash_to_int(structure) def __eq__(self, other: object) -> bool: """ @@ -80,7 +83,7 @@ def __eq__(self, other: object) -> bool: Returns: bool: True if both objects have the same identity structure, False otherwise. 
""" - if not isinstance(other, ContentHashableBase): + if not isinstance(other, ContentIdentifiableBase): return NotImplemented return self.identity_structure() == other.identity_structure() \ No newline at end of file diff --git a/src/orcapod/hashing/core.py b/src/orcapod/hashing/core.py index d37ade3..08fd812 100644 --- a/src/orcapod/hashing/core.py +++ b/src/orcapod/hashing/core.py @@ -437,8 +437,8 @@ def process_structure( logger.debug(f"Processing HashableMixin instance of type {type(obj).__name__}") return obj.content_hash() - from .content_hashable import ContentHashableBase - if isinstance(obj, ContentHashableBase): + from .content_hashable import ContentIdentifiableBase + if isinstance(obj, ContentIdentifiableBase): logger.debug(f"Processing ContentHashableBase instance of type {type(obj).__name__}") return process_structure(obj.identity_structure(), visited, function_info_extractor) diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index b6d5f18..a77d67a 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -11,7 +11,7 @@ import pyarrow as pa import polars as pl from orcapod.core.streams import SyncStreamFromGenerator -from orcapod.utils.stream_utils import get_typespec, merge_typespecs +from orcapod.utils.stream_utils import get_typespec, union_typespecs import logging logger = logging.getLogger(__name__) @@ -184,31 +184,45 @@ def __init__( kernel: Kernel, input_streams: Collection[SyncStream], output_store: ArrowDataStore, - _object_hasher: ObjectHasher | None = None, - _arrow_hasher: ArrowHasher | None = None, - _registry: TypeRegistry | None = None, + kernel_hasher: ObjectHasher | None = None, + arrow_packet_hasher: ArrowHasher | None = None, + packet_type_registry: TypeRegistry | None = None, **kwargs, ) -> None: super().__init__(kernel, input_streams,**kwargs) self.output_store = output_store - self.tag_keys, self.packet_keys = self.keys(trigger_run=False) - self.output_converter = None # These are configurable but are not expected to be modified except for special circumstances - if _object_hasher is None: - _object_hasher = get_default_object_hasher() - self.object_hasher = _object_hasher - if _arrow_hasher is None: - _arrow_hasher = get_default_arrow_hasher() - self.arrow_hasher = _arrow_hasher - if _registry is None: - _registry = default_registry - self.registry = _registry - self.source_info = self.label, str(hash(self.kernel)) + if kernel_hasher is None: + kernel_hasher = get_default_object_hasher() + self._kernel_hasher = kernel_hasher + if arrow_packet_hasher is None: + arrow_packet_hasher = get_default_arrow_hasher() + self._arrow_packet_hasher = arrow_packet_hasher + if packet_type_registry is None: + packet_type_registry = default_registry + self._packet_type_registry = packet_type_registry + + + self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) + self.tag_keys, self.packet_keys = self.keys(trigger_run=False) + self.output_converter = None self._cache_computed = False + @property + def kernel_hasher(self) -> ObjectHasher: + return self._kernel_hasher + + @kernel_hasher.setter + def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): + if kernel_hasher is None: + kernel_hasher = get_default_object_hasher() + self._kernel_hasher = kernel_hasher + # hasher changed -- trigger recomputation of properties that depend on kernel hasher + self.update_cached_values() + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: @@ -224,7 
+238,7 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: tag_type, packet_type = output_stream.types(trigger_run=False) if tag_type is not None and packet_type is not None: - joined_type = merge_typespecs(tag_type, packet_type) + joined_type = union_typespecs(tag_type, packet_type) assert joined_type is not None, "Joined typespec should not be None" self.output_converter = PacketConverter(joined_type, registry=self.registry) @@ -324,9 +338,9 @@ def __init__( skip_memoization: bool = False, skip_tag_record: bool = False, error_handling: Literal["raise", "ignore", "warn"] = "raise", - _object_hasher: ObjectHasher | None = None, - _arrow_hasher: ArrowHasher | None = None, - _registry: TypeRegistry | None = None, + object_hasher: ObjectHasher | None = None, + arrow_hasher: ArrowHasher | None = None, + registry: TypeRegistry | None = None, **kwargs, ) -> None: super().__init__( @@ -344,28 +358,64 @@ def __init__( self.skip_tag_record = skip_tag_record # These are configurable but are not expected to be modified except for special circumstances + # Here I'm assigning to the hidden properties directly to avoid triggering setters if _object_hasher is None: _object_hasher = get_default_object_hasher() - self.object_hasher = _object_hasher + self._object_hasher = _object_hasher if _arrow_hasher is None: _arrow_hasher = get_default_arrow_hasher() - self.arrow_hasher = _arrow_hasher + self._arrow_hasher = _arrow_hasher if _registry is None: _registry = default_registry - self.registry = _registry + self._registry = _registry - # TODO: consider making this dynamic - self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) - self.tag_keys, self.output_keys = self.keys(trigger_run=False) + # compute and cache properties and converters for efficiency + self.update_cached_values() + self._cache_computed = False - # prepare packet converters - input_typespec, output_typespec = self.function_pod.get_function_typespecs() + @property + def object_hasher(self) -> ObjectHasher: + return self._object_hasher - self.input_converter = PacketConverter(input_typespec, self.registry) - self.output_converter = PacketConverter(output_typespec, self.registry) + @object_hasher.setter + def object_hasher(self, object_hasher:ObjectHasher | None = None): + if object_hasher is None: + object_hasher = get_default_object_hasher() + self._object_hasher = object_hasher + # hasher changed -- trigger recomputation of properties that depend on object hasher + self.update_cached_values() - self._cache_computed = False + @property + def arrow_hasher(self) -> ArrowHasher: + return self._arrow_hasher + + @arrow_hasher.setter + def arrow_hasher(self, arrow_hasher:ArrowHasher | None = None): + if arrow_hasher is None: + arrow_hasher = get_default_arrow_hasher() + self._arrow_hasher = arrow_hasher + # hasher changed -- trigger recomputation of properties that depend on arrow hasher + self.update_cached_values() + + @property + def registry(self) -> TypeRegistry: + return self._registry + + @registry.setter + def registry(self, registry: TypeRegistry | None = None): + if registry is None: + registry = default_registry + self._registry = registry + # registry changed -- trigger recomputation of properties that depend on registry + self.update_cached_values() + + def update_cached_values(self) -> None: + self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) + self.tag_keys, self.output_keys = self.keys(trigger_run=False) + self.input_typespec, self.output_typespec = 
self.function_pod.get_function_typespecs() + self.input_converter = PacketConverter(self.input_typespec, self.registry) + self.output_converter = PacketConverter(self.output_typespec, self.registry) def reset_cache(self): self._cache_computed = False diff --git a/src/orcapod/utils/stream_utils.py b/src/orcapod/utils/stream_utils.py index 51d46c1..95703c8 100644 --- a/src/orcapod/utils/stream_utils.py +++ b/src/orcapod/utils/stream_utils.py @@ -43,7 +43,7 @@ def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: return merged -def merge_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: +def union_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: if left is None: return right if right is None: @@ -58,6 +58,25 @@ def merge_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | ) return merged +def intersection_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: + """ + Returns the intersection of two TypeSpecs, only returning keys that are present in both. + If a key is present in both TypeSpecs, the type must be the same. + """ + if left is None or right is None: + return None + # Find common keys and ensure types match + common_keys = set(left.keys()).intersection(set(right.keys())) + intersection = {} + for key in common_keys: + try: + intersection[key] = get_compatible_type(left[key], right[key]) + except TypeError: + # If types are not compatible, raise an error + raise TypeError(f"Type conflict for key '{key}': {left[key]} vs {right[key]}") + + return intersection + def common_elements(*values) -> Collection[str]: """ @@ -88,6 +107,20 @@ def join_tags(tag1: Mapping[K, V], tag2: Mapping[K, V]) -> dict[K, V] | None: joined_tag[k] = v return joined_tag +def semijoin_tags(tag1: Mapping[K, V], tag2: Mapping[K, V], target_keys: Collection[K]|None = None) -> dict[K, V] | None: + """ + Semijoin two tags. If the tags have the same key, the value must be the same or None will be returned. If all shared + key's value match, tag1 would be returned + """ + if target_keys is None: + target_keys = set(tag1.keys()).intersection(set(tag2.keys())) + if not target_keys: + return dict(tag1) + + for key in target_keys: + if tag1[key] != tag2[key]: + return None + return dict(tag1) def check_packet_compatibility(packet1: Packet, packet2: Packet) -> bool: """ From 5fb2435453bd66fadfdcd8324c1581421ceef86d Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 26 Jun 2025 18:53:22 +0000 Subject: [PATCH 10/57] style: apply ruff formatting --- src/orcapod/core/streams.py | 19 ++++-- src/orcapod/pipeline/wrappers.py | 114 ++++++++++++++++++------------- 2 files changed, 80 insertions(+), 53 deletions(-) diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index a1e9620..c70b009 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -5,7 +5,6 @@ from copy import copy - class SyncStreamFromLists(SyncStream): def __init__( self, @@ -50,16 +49,21 @@ def keys( super_tag_keys, super_packet_keys = super().keys(trigger_run=trigger_run) tag_keys = tag_keys or super_tag_keys packet_keys = packet_keys or super_packet_keys - + # If the keys are already set, return them return tag_keys, packet_keys - + def types( self, *, trigger_run: bool = False ) -> tuple[TypeSpec | None, TypeSpec | None]: - tag_typespec, packet_typespec = copy(self.tag_typespec), copy(self.packet_typespec) + tag_typespec, packet_typespec = ( + copy(self.tag_typespec), + copy(self.packet_typespec), + ) if tag_typespec is None or packet_typespec is None: - super_tag_typespec, super_packet_typespec = super().types(trigger_run=trigger_run) + super_tag_typespec, super_packet_typespec = super().types( + trigger_run=trigger_run + ) tag_typespec = tag_typespec or super_tag_typespec packet_typespec = packet_typespec or super_packet_typespec @@ -69,7 +73,6 @@ def types( def __iter__(self) -> Iterator[tuple[Tag, Packet]]: yield from self.paired - class SyncStreamFromGenerator(SyncStream): """ @@ -87,9 +90,11 @@ def __init__( self.tag_keys = tag_keys self.packet_keys = packet_keys self.generator_factory = generator_factory + self.check_consistency = False def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - yield from self.generator_factory() + if not self.check_consistency: + yield from self.generator_factory() def keys( self, *, trigger_run: bool = False diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index a77d67a..e953f1f 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -14,8 +14,10 @@ from orcapod.utils.stream_utils import get_typespec, union_typespecs import logging + logger = logging.getLogger(__name__) + def tag_to_arrow_table_with_metadata(tag, metadata: dict | None = None): """ Convert a tag dictionary to PyArrow table with metadata on each column. @@ -49,18 +51,23 @@ def tag_to_arrow_table_with_metadata(tag, metadata: dict | None = None): return table -def get_columns_with_metadata(df: pl.DataFrame, key: str, value: str|None = None) -> list[str]: + +def get_columns_with_metadata( + df: pl.DataFrame, key: str, value: str | None = None +) -> list[str]: """Get column names with specific metadata using list comprehension. If value is given, only - columns matching that specific value for the desginated metadata key will be returned. + columns matching that specific value for the desginated metadata key will be returned. 
Otherwise, all columns that contains the key as metadata will be returned regardless of the value""" return [ - col_name for col_name, dtype in df.schema.items() - if hasattr(dtype, "metadata") and (value is None or getattr(dtype, "metadata") == value) + col_name + for col_name, dtype in df.schema.items() + if hasattr(dtype, "metadata") + and (value is None or getattr(dtype, "metadata") == value) ] class PolarsSource(Source): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str]|None = None): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None): self.df = df self.tag_keys = tag_keys @@ -74,7 +81,7 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: class PolarsStream(SyncStream): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str]|None = None): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None): self.df = df if tag_keys is None: # extract tag_keys by picking columns with metadata source=tag @@ -87,8 +94,15 @@ def __iter__(self) -> Iterator[tuple[Tag, Packet]]: packet = {key: val for key, val in row.items() if key not in self.tag_keys} yield tag, packet + class EmptyStream(SyncStream): - def __init__(self, tag_keys: Collection[str]|None = None, packet_keys: Collection[str]|None = None, tag_typespec: TypeSpec | None = None, packet_typespec:TypeSpec|None = None): + def __init__( + self, + tag_keys: Collection[str] | None = None, + packet_keys: Collection[str] | None = None, + tag_typespec: TypeSpec | None = None, + packet_typespec: TypeSpec | None = None, + ): if tag_keys is None and tag_typespec is not None: tag_keys = tag_typespec.keys() self.tag_keys = list(tag_keys) if tag_keys else [] @@ -100,10 +114,14 @@ def __init__(self, tag_keys: Collection[str]|None = None, packet_keys: Collectio self.tag_typespec = tag_typespec self.packet_typespec = packet_typespec - def keys(self, *streams: SyncStream, trigger_run: bool = False) -> tuple[Collection[str] | None, Collection[str] | None]: + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: return self.tag_keys, self.packet_keys - def types(self, *streams: SyncStream, trigger_run: bool = False) -> tuple[TypeSpec | None, TypeSpec | None]: + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: return self.tag_typespec, self.packet_typespec def __iter__(self) -> Iterator[tuple[Tag, Packet]]: @@ -111,15 +129,14 @@ def __iter__(self) -> Iterator[tuple[Tag, Packet]]: return iter([]) - - class KernelInvocationWrapper(Kernel): - def __init__(self, kernel: Kernel, input_streams: Collection[SyncStream], **kwargs) -> None: + def __init__( + self, kernel: Kernel, input_streams: Collection[SyncStream], **kwargs + ) -> None: super().__init__(**kwargs) self.kernel = kernel self.input_streams = list(input_streams) - def __repr__(self): return f"{self.__class__.__name__}<{self.kernel!r}>" @@ -163,7 +180,7 @@ def claims_unique_tags( ) -> bool | None: resolved_streams = self.resolve_input_streams(*streams) return self.kernel.claims_unique_tags( - *resolved_streams, trigger_run=trigger_run + *resolved_streams, trigger_run=trigger_run ) @@ -189,7 +206,7 @@ def __init__( packet_type_registry: TypeRegistry | None = None, **kwargs, ) -> None: - super().__init__(kernel, input_streams,**kwargs) + super().__init__(kernel, input_streams, **kwargs) self.output_store = output_store @@ -204,15 +221,24 @@ def __init__( packet_type_registry = 
default_registry self._packet_type_registry = packet_type_registry - self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) self.output_converter = None self._cache_computed = False + @property + def arrow_hasher(self): + return self._arrow_packet_hasher + + @property + def registry(self): + return self._packet_type_registry + @property def kernel_hasher(self) -> ObjectHasher: + if self._kernel_hasher is None: + return get_default_object_hasher() return self._kernel_hasher @kernel_hasher.setter @@ -223,6 +249,10 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): # hasher changed -- trigger recomputation of properties that depend on kernel hasher self.update_cached_values() + def update_cached_values(self): + self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) + self.tag_keys, self.packet_keys = self.keys(trigger_run=False) + self.output_converter = None def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: @@ -233,7 +263,7 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.packet_keys) resolved_streams = self.resolve_input_streams(*streams) - + output_stream = self.kernel.forward(*resolved_streams, **kwargs) tag_type, packet_type = output_stream.types(trigger_run=False) @@ -279,26 +309,25 @@ def df(self) -> pl.DataFrame | None: return None return lazy_df.collect() - def reset_cache(self): self._cache_computed = False - class FunctionPodInvocationWrapper(KernelInvocationWrapper, Pod): """ Convenience class to wrap a function pod, providing default pass-through implementations """ - def __init__(self, function_pod: FunctionPod, input_streams: Collection[SyncStream], **kwargs): + def __init__( + self, function_pod: FunctionPod, input_streams: Collection[SyncStream], **kwargs + ): # note that this would be an alias to the self.kernel but here explicitly taken as function_pod # for better type hints # MRO will be KernelInvocationWrapper -> Pod -> Kernel super().__init__(function_pod, input_streams, **kwargs) self.function_pod = function_pod - def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: resolved_streams = self.resolve_input_streams(*streams) return super().forward(*resolved_streams, **kwargs) @@ -306,7 +335,6 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: return self.function_pod.call(tag, packet) - # =============pass through methods/properties to the underlying function pod============= def set_active(self, active=True): @@ -322,10 +350,6 @@ def is_active(self) -> bool: return self.function_pod.is_active() - - - - class CachedFunctionPodWrapper(FunctionPodInvocationWrapper, Source): def __init__( self, @@ -359,16 +383,15 @@ def __init__( # These are configurable but are not expected to be modified except for special circumstances # Here I'm assigning to the hidden properties directly to avoid triggering setters - if _object_hasher is None: - _object_hasher = get_default_object_hasher() - self._object_hasher = _object_hasher - if _arrow_hasher is None: - _arrow_hasher = get_default_arrow_hasher() - self._arrow_hasher = _arrow_hasher - if _registry is None: - _registry = default_registry - self._registry = _registry - + if object_hasher is None: + object_hasher = get_default_object_hasher() + self._object_hasher = object_hasher + if arrow_hasher is 
None: + arrow_hasher = get_default_arrow_hasher() + self._arrow_hasher = arrow_hasher + if registry is None: + registry = default_registry + self._registry = registry # compute and cache properties and converters for efficiency self.update_cached_values() @@ -379,7 +402,7 @@ def object_hasher(self) -> ObjectHasher: return self._object_hasher @object_hasher.setter - def object_hasher(self, object_hasher:ObjectHasher | None = None): + def object_hasher(self, object_hasher: ObjectHasher | None = None): if object_hasher is None: object_hasher = get_default_object_hasher() self._object_hasher = object_hasher @@ -391,7 +414,7 @@ def arrow_hasher(self) -> ArrowHasher: return self._arrow_hasher @arrow_hasher.setter - def arrow_hasher(self, arrow_hasher:ArrowHasher | None = None): + def arrow_hasher(self, arrow_hasher: ArrowHasher | None = None): if arrow_hasher is None: arrow_hasher = get_default_arrow_hasher() self._arrow_hasher = arrow_hasher @@ -413,7 +436,9 @@ def registry(self, registry: TypeRegistry | None = None): def update_cached_values(self) -> None: self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) self.tag_keys, self.output_keys = self.keys(trigger_run=False) - self.input_typespec, self.output_typespec = self.function_pod.get_function_typespecs() + self.input_typespec, self.output_typespec = ( + self.function_pod.get_function_typespecs() + ) self.input_converter = PacketConverter(self.input_typespec, self.registry) self.output_converter = PacketConverter(self.output_typespec, self.registry) @@ -435,14 +460,11 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.output_keys) logger.info(f"Computing and caching outputs for {self}") - return super().forward(*streams, **kwargs) - + return super().forward(*streams, **kwargs) def get_packet_key(self, packet: Packet) -> str: # TODO: reconsider the logic around input/output converter -- who should own this? - return self.arrow_hasher.hash_table( - self.input_converter.to_arrow_table(packet) - ) + return self.arrow_hasher.hash_table(self.input_converter.to_arrow_table(packet)) @property def source_info(self): @@ -701,15 +723,15 @@ def __init__(self, kernel: Kernel, input_nodes: Collection["Node"], **kwargs): def reset_cache(self) -> None: ... - class KernelNode(Node, CachedKernelWrapper): """ A node that wraps a Kernel and provides a Node interface. This is useful for creating nodes in a pipeline that can be executed. """ + class FunctionPodNode(Node, CachedFunctionPodWrapper): """ A node that wraps a FunctionPod and provides a Node interface. This is useful for creating nodes in a pipeline that can be executed. - """ \ No newline at end of file + """ From 09f59cbcfcabb5e24e69bc64e5bd97a6340bd582 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 26 Jun 2025 19:07:10 +0000 Subject: [PATCH 11/57] refactor: clean up test of name orcabridge --- .devcontainer/Dockerfile | 2 +- src/orcapod/core/base.py | 30 +++++++++------- src/orcapod/pipeline/pipeline.py | 35 +++++++++++-------- tests/test_hashing/generate_file_hashes.py | 3 +- .../generate_pathset_packet_hashes.py | 3 +- .../test_basic_composite_hasher.py | 1 - tests/test_hashing/test_cached_file_hasher.py | 1 - tests/test_hashing/test_file_hashes.py | 3 +- tests/test_hashing/test_packet_hasher.py | 1 - tests/test_hashing/test_pathset_and_packet.py | 3 +- .../test_pathset_packet_hashes.py | 3 +- tests/test_store/conftest.py | 1 - tests/test_store/test_dir_data_store.py | 1 - tests/test_store/test_integration.py | 1 - tests/test_store/test_noop_data_store.py | 1 - tests/test_store/test_transfer_data_store.py | 1 - tests/test_types/__init__.py | 2 +- tests/test_types/test_inference/__init__.py | 2 +- 18 files changed, 46 insertions(+), 48 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index c3b180d..33e1e11 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -26,7 +26,7 @@ RUN \ USER vscode ENV PATH=/home/vscode/.local/bin:$PATH WORKDIR /home/vscode -COPY --chown=vscode:nogroup src/orcabridge/requirements.txt /tmp/requirements.txt +COPY --chown=vscode:nogroup src/orcapod/requirements.txt /tmp/requirements.txt RUN \ # python setup curl -LsSf https://astral.sh/uv/install.sh | sh && \ diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 0a99a8a..fc18b48 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -1,4 +1,4 @@ -# Collection of base classes for operations and streams in the orcabridge framework. +# Collection of base classes for operations and streams in the orcapod framework. import threading from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterator @@ -30,12 +30,13 @@ class Kernel(ABC, ContentIdentifiableBase): for computational graph tracking. """ - def __init__(self, label: str | None = None, skip_tracking: bool = False, **kwargs) -> None: + def __init__( + self, label: str | None = None, skip_tracking: bool = False, **kwargs + ) -> None: super().__init__(**kwargs) self._label = label self._skip_tracking = skip_tracking - def pre_forward_hook( self, *streams: "SyncStream", **kwargs ) -> tuple["SyncStream", ...]: @@ -54,8 +55,9 @@ def post_forward_hook(self, output_stream: "SyncStream", **kwargs) -> "SyncStrea """ return output_stream - - def __call__(self, *streams: "SyncStream", label:str|None = None, **kwargs) -> "SyncStream": + def __call__( + self, *streams: "SyncStream", label: str | None = None, **kwargs + ) -> "SyncStream": if label is not None: self.label = label # Special handling of Source: trigger call on source if passed as stream @@ -305,7 +307,6 @@ def computed_label(self) -> str | None: # use the invocation operation label return self.invocation.kernel.label return None - @property def invocation(self) -> Invocation | None: @@ -433,7 +434,7 @@ def __len__(self) -> int: """ return sum(1 for _ in self) - def join(self, other: "SyncStream", label:str|None=None) -> "SyncStream": + def join(self, other: "SyncStream", label: str | None = None) -> "SyncStream": """ Returns a new stream that is the result of joining with the other stream. The join is performed on the tags of the packets in the streams. 
@@ -455,7 +456,12 @@ def semijoin(self, other: "SyncStream", label: str | None = None) -> "SyncStream raise TypeError("other must be a SyncStream") return SemiJoin(label=label)(self, other) - def map(self, packet_map: dict | None = None, tag_map: dict | None = None, drop_unmapped:bool=True) -> "SyncStream": + def map( + self, + packet_map: dict | None = None, + tag_map: dict | None = None, + drop_unmapped: bool = True, + ) -> "SyncStream": """ Returns a new stream that is the result of mapping the packets and tags in the stream. The mapping is applied to each packet in the stream and the resulting packets @@ -464,6 +470,7 @@ def map(self, packet_map: dict | None = None, tag_map: dict | None = None, drop_ If tag_map is None, no mapping is applied to the tags. """ from .operators import MapTags, MapPackets + output = self if packet_map is not None: output = MapPackets(packet_map, drop_unmapped=drop_unmapped)(output) @@ -472,7 +479,7 @@ def map(self, packet_map: dict | None = None, tag_map: dict | None = None, drop_ return output - def apply(self, transformer: 'dict | Operator') -> "SyncStream": + def apply(self, transformer: "dict | Operator") -> "SyncStream": """ Returns a new stream that is the result of applying the mapping to the stream. The mapping is applied to each packet in the stream and the resulting packets @@ -487,9 +494,7 @@ def apply(self, transformer: 'dict | Operator') -> "SyncStream": return transformer(self) # Otherwise, do not know how to handle the transformer - raise TypeError( - "transformer must be a dictionary or an operator" - ) + raise TypeError("transformer must be a dictionary or an operator") def __rshift__( self, transformer: dict | Callable[["SyncStream"], "SyncStream"] @@ -521,7 +526,6 @@ def __mul__(self, other: "SyncStream") -> "SyncStream": raise TypeError("other must be a SyncStream") return Join()(self, other) - def claims_unique_tags(self, *, trigger_run=False) -> bool | None: """ For synchronous streams, if the stream is generated by an operation, the invocation diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 7edd03e..74eb998 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -31,7 +31,13 @@ class Pipeline(GraphTracker): Replaces the old Tracker with better persistence and view capabilities. 
""" - def __init__(self, name: str, results_store: ArrowDataStore, pipeline_store: ArrowDataStore, auto_compile:bool=True) -> None: + def __init__( + self, + name: str, + results_store: ArrowDataStore, + pipeline_store: ArrowDataStore, + auto_compile: bool = True, + ) -> None: super().__init__() self.name = name or f"pipeline_{id(self)}" self.results_store = results_store @@ -55,7 +61,7 @@ def save(self, path: Path | str) -> None: "metadata": { "created_at": time.time(), "python_version": sys.version_info[:2], - "orcabridge_version": "0.1.0", # You can make this dynamic + "orcapod_version": "0.1.0", # TODO: make this dynamic }, } @@ -79,22 +85,26 @@ def record(self, invocation: Invocation) -> None: super().record(invocation) self._dirty = True - def wrap_invocation( - self, kernel: Kernel, input_nodes: Collection[Node] - ) -> Node: + def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node: if isinstance(kernel, FunctionPod): - return FunctionPodNode(kernel, input_nodes, output_store=self.results_store, tag_store=self.pipeline_store) + return FunctionPodNode( + kernel, + input_nodes, + output_store=self.results_store, + tag_store=self.pipeline_store, + ) return KernelNode(kernel, input_nodes, output_store=self.pipeline_store) def compile(self): import networkx as nx + G = self.generate_graph() # Proposed labels for each Kernel in the graph # If name collides, unique name is generated by appending an index proposed_labels = defaultdict(list) node_lut = {} - edge_lut : dict[SyncStream, Node]= {} + edge_lut: dict[SyncStream, Node] = {} ordered_nodes = [] for invocation in nx.topological_sort(G): # map streams to the new streams based on Nodes @@ -109,7 +119,7 @@ def compile(self): for edge in G.out_edges(invocation): edge_lut[G.edges[edge]["stream"]] = new_node - + self._ordered_nodes = ordered_nodes # resolve duplicates in proposed_labels @@ -134,18 +144,17 @@ def __exit__(self, exc_type, exc_val, ext_tb): if self.auto_compile: self.compile() - def __getattr__(self, item: str) -> Any: """Allow direct access to pipeline attributes""" if item in self.labels_to_nodes: return self.labels_to_nodes[item] raise AttributeError(f"Pipeline has no attribute '{item}'") - + def __dir__(self): # Include both regular attributes and dynamic ones return list(super().__dir__()) + list(self.labels_to_nodes.keys()) - def run(self, full_sync:bool=False) -> None: + def run(self, full_sync: bool = False) -> None: """ Run the pipeline, compiling it if necessary. This method is a no-op if auto_compile is False. @@ -158,7 +167,7 @@ def run(self, full_sync:bool=False) -> None: if full_sync: node.reset_cache() node.flow() - + @classmethod def load(cls, path: Path | str) -> "Pipeline": """Load complete pipeline state""" @@ -196,5 +205,3 @@ def _validate_serializable(self) -> None: + "\n".join(f" - {issue}" for issue in issues) + "\n\nOnly named functions are supported for serialization." ) - - diff --git a/tests/test_hashing/generate_file_hashes.py b/tests/test_hashing/generate_file_hashes.py index 1002b7f..0beb66c 100644 --- a/tests/test_hashing/generate_file_hashes.py +++ b/tests/test_hashing/generate_file_hashes.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/generate_file_hashes.py """ Generate sample files with random content and record their hashes. 
@@ -14,7 +13,7 @@ from datetime import datetime from pathlib import Path -# Add the parent directory to the path to import orcabridge +# Add the parent directory to the path to import orcapod sys.path.append(str(Path(__file__).parent.parent.parent)) from orcapod.hashing import hash_file diff --git a/tests/test_hashing/generate_pathset_packet_hashes.py b/tests/test_hashing/generate_pathset_packet_hashes.py index 61a36eb..edd804d 100644 --- a/tests/test_hashing/generate_pathset_packet_hashes.py +++ b/tests/test_hashing/generate_pathset_packet_hashes.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/generate_pathset_packet_hashes.py """ Generate sample pathsets and packets and record their hashes. @@ -11,7 +10,7 @@ import sys from pathlib import Path -# Add the parent directory to the path to import orcabridge +# Add the parent directory to the path to import orcapod sys.path.append(str(Path(__file__).parent.parent.parent)) from orcapod.hashing import hash_packet, hash_pathset diff --git a/tests/test_hashing/test_basic_composite_hasher.py b/tests/test_hashing/test_basic_composite_hasher.py index d2c5361..2ef9cf6 100644 --- a/tests/test_hashing/test_basic_composite_hasher.py +++ b/tests/test_hashing/test_basic_composite_hasher.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_default_file_hasher.py """ Test DefaultFileHasher functionality. diff --git a/tests/test_hashing/test_cached_file_hasher.py b/tests/test_hashing/test_cached_file_hasher.py index 3307628..42c9380 100644 --- a/tests/test_hashing/test_cached_file_hasher.py +++ b/tests/test_hashing/test_cached_file_hasher.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_cached_file_hasher.py """Tests for CachedFileHasher implementation.""" import json diff --git a/tests/test_hashing/test_file_hashes.py b/tests/test_hashing/test_file_hashes.py index 66ed987..1de0716 100644 --- a/tests/test_hashing/test_file_hashes.py +++ b/tests/test_hashing/test_file_hashes.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_file_hashes.py """ Test file hash consistency. @@ -12,7 +11,7 @@ import pytest -# Add the parent directory to the path to import orcabridge +# Add the parent directory to the path to import orcapod from orcapod.hashing import hash_file diff --git a/tests/test_hashing/test_packet_hasher.py b/tests/test_hashing/test_packet_hasher.py index f9d519d..69b89d0 100644 --- a/tests/test_hashing/test_packet_hasher.py +++ b/tests/test_hashing/test_packet_hasher.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_packet_hasher.py """Tests for the PacketHasher protocol implementation.""" import pytest diff --git a/tests/test_hashing/test_pathset_and_packet.py b/tests/test_hashing/test_pathset_and_packet.py index 6b7eb6f..fc00b29 100644 --- a/tests/test_hashing/test_pathset_and_packet.py +++ b/tests/test_hashing/test_pathset_and_packet.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_pathset_and_packet.py """ -Test the hash_pathset and hash_packet functions from orcabridge.hashing. +Test the hash_pathset and hash_packet functions from orcapod.hashing. 
This module contains tests to verify the correct behavior of hash_pathset and hash_packet functions with various input types and configurations. diff --git a/tests/test_hashing/test_pathset_packet_hashes.py b/tests/test_hashing/test_pathset_packet_hashes.py index 49e2d0c..7745881 100644 --- a/tests/test_hashing/test_pathset_packet_hashes.py +++ b/tests/test_hashing/test_pathset_packet_hashes.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_pathset_packet_hashes.py """ Test pathset and packet hash consistency. @@ -12,7 +11,7 @@ import pytest -# Add the parent directory to the path to import orcabridge +# Add the parent directory to the path to import orcapod from orcapod.hashing import hash_packet, hash_pathset diff --git a/tests/test_store/conftest.py b/tests/test_store/conftest.py index 77ca9f9..6b8aa6f 100644 --- a/tests/test_store/conftest.py +++ b/tests/test_store/conftest.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_store/conftest.py """Common test fixtures for store tests.""" import shutil diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index c07f141..16d82d0 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_store/test_dir_data_store.py """Tests for DirDataStore.""" import json diff --git a/tests/test_store/test_integration.py b/tests/test_store/test_integration.py index 023e6e6..fc26022 100644 --- a/tests/test_store/test_integration.py +++ b/tests/test_store/test_integration.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_store/test_integration.py """Integration tests for the store module.""" import os diff --git a/tests/test_store/test_noop_data_store.py b/tests/test_store/test_noop_data_store.py index 0da82c7..42606b8 100644 --- a/tests/test_store/test_noop_data_store.py +++ b/tests/test_store/test_noop_data_store.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_store/test_noop_data_store.py """Tests for NoOpDataStore.""" import pytest diff --git a/tests/test_store/test_transfer_data_store.py b/tests/test_store/test_transfer_data_store.py index 85d0a87..1e8d178 100644 --- a/tests/test_store/test_transfer_data_store.py +++ b/tests/test_store/test_transfer_data_store.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_store/test_transfer_data_store.py """Tests for TransferDataStore.""" import json diff --git a/tests/test_types/__init__.py b/tests/test_types/__init__.py index aa691b1..2be2a50 100644 --- a/tests/test_types/__init__.py +++ b/tests/test_types/__init__.py @@ -1 +1 @@ -# Test package for orcabridge types module +# Test package for orcapod types module diff --git a/tests/test_types/test_inference/__init__.py b/tests/test_types/test_inference/__init__.py index 45e6baf..ae4cff0 100644 --- a/tests/test_types/test_inference/__init__.py +++ b/tests/test_types/test_inference/__init__.py @@ -1 +1 @@ -# Test package for orcabridge types inference module +# Test package for orcapod types inference module From c5fcb3d33a580262f30e56ec4671c70c00c6961e Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 27 Jun 2025 05:19:19 +0000 Subject: [PATCH 12/57] test: remove filepath specification --- tests/test_hashing/test_hasher_parity.py | 1 - tests/test_store/test_dir_data_store.py | 2 +- tests/test_store/test_integration.py | 2 +- tests/test_store/test_noop_data_store.py | 4 +- tests/test_store/test_transfer_data_store.py | 4 +- .../test_extract_function_data_types.py | 60 +++++++++---------- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/tests/test_hashing/test_hasher_parity.py b/tests/test_hashing/test_hasher_parity.py index fb83afb..64a6004 100644 --- a/tests/test_hashing/test_hasher_parity.py +++ b/tests/test_hashing/test_hasher_parity.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_hasher_parity.py """ Test parity between DefaultFileHasher and core hashing functions. diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index 16d82d0..32d8618 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -13,7 +13,7 @@ PacketHasher, PathSetHasher, ) -from orcapod.store.core import DirDataStore +from orcapod.store.dict_data_stores import DirDataStore class MockFileHasher(FileHasher): diff --git a/tests/test_store/test_integration.py b/tests/test_store/test_integration.py index fc26022..48e0703 100644 --- a/tests/test_store/test_integration.py +++ b/tests/test_store/test_integration.py @@ -12,7 +12,7 @@ DefaultCompositeFileHasher, ) from orcapod.hashing.string_cachers import InMemoryCacher -from orcapod.store.core import DirDataStore, NoOpDataStore +from orcapod.store.dict_data_stores import DirDataStore, NoOpDataStore def test_integration_with_cached_file_hasher(temp_dir, sample_files): diff --git a/tests/test_store/test_noop_data_store.py b/tests/test_store/test_noop_data_store.py index 42606b8..ab0eecd 100644 --- a/tests/test_store/test_noop_data_store.py +++ b/tests/test_store/test_noop_data_store.py @@ -3,7 +3,7 @@ import pytest -from orcapod.store.core import NoOpDataStore +from orcapod.store.dict_data_stores import NoOpDataStore def test_noop_data_store_memoize(): @@ -43,7 +43,7 @@ def test_noop_data_store_retrieve_memoized(): def test_noop_data_store_is_data_store_subclass(): """Test that NoOpDataStore is a subclass of DataStore.""" - from orcapod.store.core import DataStore + from orcapod.store import DataStore store = NoOpDataStore() assert isinstance(store, DataStore) diff --git a/tests/test_store/test_transfer_data_store.py b/tests/test_store/test_transfer_data_store.py index 1e8d178..6fd2add 100644 --- a/tests/test_store/test_transfer_data_store.py +++ b/tests/test_store/test_transfer_data_store.py @@ -7,8 +7,8 @@ import pytest from orcapod.hashing.types import PacketHasher -from orcapod.store.core import DirDataStore, NoOpDataStore -from orcapod.store.transfer import TransferDataStore +from orcapod.store.dict_data_stores import DirDataStore, NoOpDataStore +from orcapod.store.transfer_data_store import TransferDataStore class MockPacketHasher(PacketHasher): diff --git a/tests/test_types/test_inference/test_extract_function_data_types.py b/tests/test_types/test_inference/test_extract_function_data_types.py index a357bb0..e96fd9c 100644 --- a/tests/test_types/test_inference/test_extract_function_data_types.py +++ b/tests/test_types/test_inference/test_extract_function_data_types.py @@ -1,5 +1,5 @@ """ -Unit tests for the extract_function_data_types function. 
+Unit tests for the extract_function_typespecs function. This module tests the function type extraction functionality, covering: - Type inference from function annotations @@ -11,11 +11,11 @@ import pytest from collections.abc import Collection -from orcapod.types.inference import extract_function_data_types +from orcapod.types.typespec import extract_function_typespecs class TestExtractFunctionDataTypes: - """Test cases for extract_function_data_types function.""" + """Test cases for extract_function_typespecs function.""" def test_simple_annotated_function(self): """Test function with simple type annotations.""" @@ -23,7 +23,7 @@ def test_simple_annotated_function(self): def add(x: int, y: int) -> int: return x + y - input_types, output_types = extract_function_data_types(add, ["result"]) + input_types, output_types = extract_function_typespecs(add, ["result"]) assert input_types == {"x": int, "y": int} assert output_types == {"result": int} @@ -34,7 +34,7 @@ def test_multiple_return_values_tuple(self): def process(data: str) -> tuple[int, str]: return len(data), data.upper() - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( process, ["length", "upper_data"] ) @@ -54,7 +54,7 @@ def split_data(data: str) -> tuple[str, str]: # Note: This tests the case where we have multiple output keys # but the return type is list[str] (homogeneous) - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( split_data, ["first_word", "second_word"] ) @@ -71,7 +71,7 @@ def mystery_func(x: int): ValueError, match="Type for return item 'number' is not specified in output_types", ): - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( mystery_func, ["number", "text"], ) @@ -82,7 +82,7 @@ def test_input_types_override(self): def legacy_func(x, y) -> int: # No annotations return x + y - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( legacy_func, ["sum"], input_types={"x": int, "y": int} ) @@ -95,7 +95,7 @@ def test_partial_input_types_override(self): def mixed_func(x: int, y) -> int: # One annotated, one not return x + y - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( mixed_func, ["sum"], input_types={"y": float} ) @@ -108,7 +108,7 @@ def test_output_types_dict_override(self): def mystery_func(x: int) -> str: return str(x) - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( mystery_func, ["result"], output_types={"result": float} ) @@ -121,7 +121,7 @@ def test_output_types_sequence_override(self): def multi_return(data: list) -> tuple[int, float, str]: return len(data), sum(data), str(data) - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( multi_return, ["count", "total", "repr"], output_types=[int, float, str] ) @@ -134,7 +134,7 @@ def test_complex_types(self): def complex_func(x: str | None, y: int | float) -> tuple[bool, list[str]]: return bool(x), [x] if x else [] - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( complex_func, ["is_valid", "items"] ) @@ -147,7 +147,7 @@ def test_none_return_annotation(self): def side_effect_func(x: int) -> None: print(x) - input_types, output_types = 
extract_function_data_types(side_effect_func, []) + input_types, output_types = extract_function_typespecs(side_effect_func, []) assert input_types == {"x": int} assert output_types == {} @@ -158,7 +158,7 @@ def test_empty_parameters(self): def get_constant() -> int: return 42 - input_types, output_types = extract_function_data_types(get_constant, ["value"]) + input_types, output_types = extract_function_typespecs(get_constant, ["value"]) assert input_types == {} assert output_types == {"value": int} @@ -172,7 +172,7 @@ def bad_func(x, y: int): return x + y with pytest.raises(ValueError, match="Parameter 'x' has no type annotation"): - extract_function_data_types(bad_func, ["result"]) + extract_function_typespecs(bad_func, ["result"]) def test_return_annotation_but_no_output_keys_error(self): """Test error when function has return annotation but no output keys.""" @@ -184,7 +184,7 @@ def func_with_return(x: int) -> str: ValueError, match="Function has a return type annotation, but no return keys were specified", ): - extract_function_data_types(func_with_return, []) + extract_function_typespecs(func_with_return, []) def test_none_return_with_output_keys_error(self): """Test error when function returns None but output keys provided.""" @@ -196,7 +196,7 @@ def side_effect_func(x: int) -> None: ValueError, match="Function provides explicit return type annotation as None", ): - extract_function_data_types(side_effect_func, ["result"]) + extract_function_typespecs(side_effect_func, ["result"]) def test_single_return_multiple_keys_error(self): """Test error when single return type but multiple output keys.""" @@ -208,7 +208,7 @@ def single_return(x: int) -> str: ValueError, match="Multiple return keys were specified but return type annotation .* is not a sequence type", ): - extract_function_data_types(single_return, ["first", "second"]) + extract_function_typespecs(single_return, ["first", "second"]) def test_unparameterized_sequence_type_error(self): """Test error when return type is sequence but not parameterized.""" @@ -219,7 +219,7 @@ def bad_return(x: int) -> tuple: # tuple without types with pytest.raises( ValueError, match="is a Sequence type but does not specify item types" ): - extract_function_data_types(bad_return, ["number", "text"]) + extract_function_typespecs(bad_return, ["number", "text"]) def test_mismatched_return_types_count_error(self): """Test error when return type count doesn't match output keys count.""" @@ -230,7 +230,7 @@ def three_returns(x: int) -> tuple[int, str, float]: with pytest.raises( ValueError, match="has 3 items, but output_keys has 2 items" ): - extract_function_data_types(three_returns, ["first", "second"]) + extract_function_typespecs(three_returns, ["first", "second"]) def test_mismatched_output_types_sequence_length_error(self): """Test error when output_types sequence length doesn't match output_keys.""" @@ -242,7 +242,7 @@ def func(x: int) -> tuple[int, str]: ValueError, match="Output types collection length .* does not match return keys length", ): - extract_function_data_types( + extract_function_typespecs( func, ["first", "second"], output_types=[int, str, float], # Wrong length @@ -258,7 +258,7 @@ def no_return_annotation(x: int): ValueError, match="Type for return item 'first' is not specified in output_types", ): - extract_function_data_types(no_return_annotation, ["first", "second"]) + extract_function_typespecs(no_return_annotation, ["first", "second"]) # Edge cases @@ -268,7 +268,7 @@ def test_callable_with_args_kwargs(self): def 
flexible_func(x: int, *args: str, **kwargs: float) -> bool: return True - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( flexible_func, ["success"] ) @@ -284,7 +284,7 @@ def test_mixed_override_scenarios(self): def complex_func(a, b: str) -> tuple[int, str]: return len(b), b.upper() - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( complex_func, ["length", "upper"], input_types={"a": float}, @@ -300,7 +300,7 @@ def test_generic_types(self): def generic_func(data: list[int]) -> dict[str, int]: return {str(i): i for i in data} - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( generic_func, ["mapping"] ) @@ -316,7 +316,7 @@ def list_func( return str(x), x # This tests the sequence detection logic - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( list_func, ["text", "number"] ) @@ -330,7 +330,7 @@ def collection_func(x: int) -> Collection[str]: return [str(x)] # Single output key with Collection type - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( collection_func, ["result"] ) @@ -347,7 +347,7 @@ def test_empty_function(self): def empty_func(): pass - input_types, output_types = extract_function_data_types(empty_func, []) + input_types, output_types = extract_function_typespecs(empty_func, []) assert input_types == {} assert output_types == {} @@ -364,7 +364,7 @@ class Container(Generic[T]): def generic_container_func(x: Container[int]) -> Container[str]: return Container() - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( generic_container_func, ["result"] ) @@ -377,7 +377,7 @@ def test_output_types_dict_partial_override(self): def three_output_func() -> tuple[int, str, float]: return 1, "hello", 3.14 - input_types, output_types = extract_function_data_types( + input_types, output_types = extract_function_typespecs( three_output_func, ["num", "text", "decimal"], output_types={"text": bytes}, # Override only middle one From 22215ca49badc65ab894289fefad6ea7712455dc Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 27 Jun 2025 05:19:37 +0000 Subject: [PATCH 13/57] fix: remove orcabridge reference --- misc/demo_redis_mocking.py | 6 +++--- src/orcapod/types/__init__.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/misc/demo_redis_mocking.py b/misc/demo_redis_mocking.py index cc18dcb..2fd1f92 100644 --- a/misc/demo_redis_mocking.py +++ b/misc/demo_redis_mocking.py @@ -72,10 +72,10 @@ def demonstrate_redis_mocking(): # Patch the Redis availability and exceptions with ( - patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True), - patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError), + patch("orcapod.hashing.string_cachers.REDIS_AVAILABLE", True), + patch("orcapod.hashing.string_cachers.redis.RedisError", MockRedisError), patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", + "orcapod.hashing.string_cachers.redis.ConnectionError", MockConnectionError, ), ): diff --git a/src/orcapod/types/__init__.py b/src/orcapod/types/__init__.py index cbcfffc..e51a6f8 100644 --- a/src/orcapod/types/__init__.py +++ b/src/orcapod/types/__init__.py @@ -1,4 +1,3 @@ -# src/orcabridge/types.py from .core import Tag, Packet, TypeSpec, PathLike, PathSet, PodFunction from .registry import TypeRegistry from .handlers import PathHandler, UUIDHandler, DateTimeHandler From 56d559a3613d916bb8d56bc04509bb7d984e53ba Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:41:04 +0000 Subject: [PATCH 14/57] refactor: rename module to match class --- ...nt_hashable.py => content_identifiable.py} | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) rename src/orcapod/hashing/{content_hashable.py => content_identifiable.py} (88%) diff --git a/src/orcapod/hashing/content_hashable.py b/src/orcapod/hashing/content_identifiable.py similarity index 88% rename from src/orcapod/hashing/content_hashable.py rename to src/orcapod/hashing/content_identifiable.py index 61eb0e5..1581e62 100644 --- a/src/orcapod/hashing/content_hashable.py +++ b/src/orcapod/hashing/content_identifiable.py @@ -1,22 +1,27 @@ - from .types import ObjectHasher from .defaults import get_default_object_hasher from typing import Any class ContentIdentifiableBase: - def __init__(self, identity_structure_hasher: ObjectHasher | None = None, label: str | None = None) -> None: + def __init__( + self, + identity_structure_hasher: ObjectHasher | None = None, + label: str | None = None, + ) -> None: """ Initialize the ContentHashable with an optional ObjectHasher. Args: identity_structure_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. """ - self.identity_structure_hasher = identity_structure_hasher or get_default_object_hasher() + self.identity_structure_hasher = ( + identity_structure_hasher or get_default_object_hasher() + ) self._label = label @property - def label(self) -> str : + def label(self) -> str: """ Get the label of this object. @@ -35,13 +40,12 @@ def label(self, label: str | None) -> None: """ self._label = label - def computed_label(self) -> str|None: + def computed_label(self) -> str | None: """ Compute a label for this object based on its content. If label is not explicitly set for this object and computed_label returns a valid value, it will be used as label of this object. 
""" return None - def identity_structure(self) -> Any: """ @@ -56,7 +60,6 @@ def identity_structure(self) -> Any: """ return None - def __hash__(self) -> int: """ Hash implementation that uses the identity structure if provided, @@ -72,7 +75,7 @@ def __hash__(self) -> int: return super().__hash__() return self.identity_structure_hasher.hash_to_int(structure) - + def __eq__(self, other: object) -> bool: """ Equality check that compares the identity structures of two objects. @@ -86,4 +89,4 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, ContentIdentifiableBase): return NotImplemented - return self.identity_structure() == other.identity_structure() \ No newline at end of file + return self.identity_structure() == other.identity_structure() From 59ad526334f977180820fa476a740972740f7e51 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:41:26 +0000 Subject: [PATCH 15/57] refactor: move core to legacy_core --- src/orcapod/hashing/__init__.py | 4 ++-- .../hashing/{core.py => legacy_core.py} | 22 ++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) rename src/orcapod/hashing/{core.py => legacy_core.py} (98%) diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index d3d83e9..2354696 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -1,4 +1,4 @@ -from .core import ( +from .legacy_core import ( HashableMixin, function_content_hash, get_function_signature, @@ -24,7 +24,7 @@ FunctionInfoExtractor, CompositeFileHasher, ) -from .content_hashable import ContentIdentifiableBase +from .content_identifiable import ContentIdentifiableBase __all__ = [ "FileHasher", diff --git a/src/orcapod/hashing/core.py b/src/orcapod/hashing/legacy_core.py similarity index 98% rename from src/orcapod/hashing/core.py rename to src/orcapod/hashing/legacy_core.py index 08fd812..cfe9c56 100644 --- a/src/orcapod/hashing/core.py +++ b/src/orcapod/hashing/legacy_core.py @@ -5,7 +5,8 @@ A library for creating stable, content-based hashes that remain consistent across Python sessions, suitable for arbitrarily nested data structures and custom objects via HashableMixin. """ -WARN_NONE_IDENTITY=False + +WARN_NONE_IDENTITY = False import hashlib import inspect import json @@ -436,11 +437,16 @@ def process_structure( if isinstance(obj, HashableMixin): logger.debug(f"Processing HashableMixin instance of type {type(obj).__name__}") return obj.content_hash() - - from .content_hashable import ContentIdentifiableBase + + from .content_identifiable import ContentIdentifiableBase + if isinstance(obj, ContentIdentifiableBase): - logger.debug(f"Processing ContentHashableBase instance of type {type(obj).__name__}") - return process_structure(obj.identity_structure(), visited, function_info_extractor) + logger.debug( + f"Processing ContentHashableBase instance of type {type(obj).__name__}" + ) + return process_structure( + obj.identity_structure(), visited, function_info_extractor + ) # Handle basic types if isinstance(obj, (str, int, float, bool)): @@ -838,7 +844,7 @@ def get_function_signature( name_override: str | None = None, include_defaults: bool = True, include_module: bool = True, - output_names: Collection[str] | None = None + output_names: Collection[str] | None = None, ) -> str: """ Get a stable string representation of a function's signature. 
@@ -877,9 +883,9 @@ def get_function_signature( if sig.return_annotation is not inspect.Signature.empty: parts["returns"] = sig.return_annotation - fn_string = f"{parts["module"] + "." if "module" in parts else ""}{parts["name"]}{parts["params"]}" + fn_string = f"{parts['module'] + '.' if 'module' in parts else ''}{parts['name']}{parts['params']}" if "returns" in parts: - fn_string = fn_string + f"-> {str(parts["returns"])}" + fn_string = fn_string + f"-> {str(parts['returns'])}" return fn_string From 3e0cdf40ab9669067c50f015f1c0f789ae2603a4 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:41:57 +0000 Subject: [PATCH 16/57] fix: update reference to core --- src/orcapod/store/safe_dir_data_store.py | 2 +- tests/test_hashing/test_basic_hashing.py | 2 +- tests/test_hashing/test_composite_hasher.py | 2 +- tests/test_hashing/test_path_set_hasher.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/orcapod/store/safe_dir_data_store.py b/src/orcapod/store/safe_dir_data_store.py index 0f0ce6a..7e16f63 100644 --- a/src/orcapod/store/safe_dir_data_store.py +++ b/src/orcapod/store/safe_dir_data_store.py @@ -205,7 +205,7 @@ def __init__( def _get_output_dir(self, function_name, content_hash, packet): """Get the output directory for a specific packet""" - from orcapod.hashing.core import hash_dict + from orcapod.hashing.legacy_core import hash_dict packet_hash = hash_dict(packet) return self.store_dir / function_name / content_hash / str(packet_hash) diff --git a/tests/test_hashing/test_basic_hashing.py b/tests/test_hashing/test_basic_hashing.py index df90a1a..c67723a 100644 --- a/tests/test_hashing/test_basic_hashing.py +++ b/tests/test_hashing/test_basic_hashing.py @@ -1,4 +1,4 @@ -from orcapod.hashing.core import ( +from orcapod.hashing.legacy_core import ( HashableMixin, hash_to_hex, hash_to_int, diff --git a/tests/test_hashing/test_composite_hasher.py b/tests/test_hashing/test_composite_hasher.py index 1cbe386..f92cfea 100644 --- a/tests/test_hashing/test_composite_hasher.py +++ b/tests/test_hashing/test_composite_hasher.py @@ -5,7 +5,7 @@ import pytest -from orcapod.hashing.core import hash_to_hex +from orcapod.hashing.legacy_core import hash_to_hex from orcapod.hashing.file_hashers import BasicFileHasher, DefaultCompositeFileHasher from orcapod.hashing.types import FileHasher, PacketHasher, PathSetHasher diff --git a/tests/test_hashing/test_path_set_hasher.py b/tests/test_hashing/test_path_set_hasher.py index 999cc2a..65e626a 100644 --- a/tests/test_hashing/test_path_set_hasher.py +++ b/tests/test_hashing/test_path_set_hasher.py @@ -8,7 +8,7 @@ import pytest -import orcapod.hashing.core +import orcapod.hashing.legacy_core from orcapod.hashing.file_hashers import DefaultPathsetHasher from orcapod.hashing.types import FileHasher @@ -35,7 +35,7 @@ def create_temp_file(content="test content"): # Store original function for restoration -original_hash_pathset = orcapod.hashing.core.hash_pathset +original_hash_pathset = orcapod.hashing.legacy_core.hash_pathset # Custom implementation of hash_pathset for tests that doesn't check for file existence @@ -46,7 +46,7 @@ def mock_hash_pathset( from collections.abc import Collection from os import PathLike - from orcapod.hashing.core import hash_to_hex + from orcapod.hashing.legacy_core import hash_to_hex from orcapod.utils.name import find_noncolliding_name # If file_hasher is None, we'll need to handle it differently From 50e07722d07cb19b3c5a1ab555c7acf0783bbd61 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 27 Jun 2025 07:42:21 +0000 Subject: [PATCH 17/57] refactor: rename semantic arrow hasher module to generic arrow hashers --- ...antic_arrow_hasher.py => arrow_hashers.py} | 81 +------------------ 1 file changed, 2 insertions(+), 79 deletions(-) rename src/orcapod/hashing/{semantic_arrow_hasher.py => arrow_hashers.py} (70%) diff --git a/src/orcapod/hashing/semantic_arrow_hasher.py b/src/orcapod/hashing/arrow_hashers.py similarity index 70% rename from src/orcapod/hashing/semantic_arrow_hasher.py rename to src/orcapod/hashing/arrow_hashers.py index f3682ed..728b904 100644 --- a/src/orcapod/hashing/semantic_arrow_hasher.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -1,87 +1,10 @@ import hashlib -import os -from typing import Any, Protocol -from abc import ABC, abstractmethod +from typing import Any import pyarrow as pa import pyarrow.ipc as ipc from io import BytesIO import polars as pl - - -class SemanticTypeHasher(Protocol): - """Abstract base class for semantic type-specific hashers.""" - - @abstractmethod - def hash_column(self, column: pa.Array) -> bytes: - """Hash a column with this semantic type and return the hash bytes.""" - pass - - -class PathHasher(SemanticTypeHasher): - """Hasher for Path semantic type columns - hashes file contents.""" - - def __init__(self, chunk_size: int = 8192, handle_missing: str = "error"): - """ - Initialize PathHasher. - - Args: - chunk_size: Size of chunks to read files in bytes - handle_missing: How to handle missing files ('error', 'skip', 'null_hash') - """ - self.chunk_size = chunk_size - self.handle_missing = handle_missing - - def _hash_file_content(self, file_path: str) -> str: - """Hash the content of a single file and return hex string.""" - import os - - try: - if not os.path.exists(file_path): - if self.handle_missing == "error": - raise FileNotFoundError(f"File not found: {file_path}") - elif self.handle_missing == "skip": - return hashlib.sha256(b"").hexdigest() - elif self.handle_missing == "null_hash": - return hashlib.sha256(b"").hexdigest() - - hasher = hashlib.sha256() - - # Read file in chunks to handle large files efficiently - with open(file_path, "rb") as f: - while chunk := f.read(self.chunk_size): - hasher.update(chunk) - - return hasher.hexdigest() - - except (IOError, OSError, PermissionError) as e: - if self.handle_missing == "error": - raise IOError(f"Cannot read file {file_path}: {e}") - else: # skip or null_hash - error_msg = f"" - return hashlib.sha256(error_msg.encode("utf-8")).hexdigest() - - def hash_column(self, column: pa.Array) -> pa.Array: - """ - Replace path column with file content hashes. - Returns a new array where each path is replaced with its file content hash. - """ - - # Convert to python list for processing - paths = column.to_pylist() - - # Hash each file's content individually - content_hashes = [] - for path in paths: - if path is not None: - # Normalize path for consistency - normalized_path = os.path.normpath(str(path)) - file_content_hash = self._hash_file_content(normalized_path) - content_hashes.append(file_content_hash) - else: - content_hashes.append(None) # Preserve nulls - - # Return new array with content hashes instead of paths - return pa.array(content_hashes) +from .types import SemanticTypeHasher class SemanticArrowHasher: From 33103b8668602a1b66f94a005b9d0f0455dc1c08 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 27 Jun 2025 07:44:04 +0000 Subject: [PATCH 18/57] refactor: rename variables to typespec --- src/orcapod/hashing/function_info_extractors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/orcapod/hashing/function_info_extractors.py b/src/orcapod/hashing/function_info_extractors.py index 2c32f05..816208b 100644 --- a/src/orcapod/hashing/function_info_extractors.py +++ b/src/orcapod/hashing/function_info_extractors.py @@ -14,8 +14,8 @@ def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_types: TypeSpec | None = None, - output_types: TypeSpec | None = None, + input_typespec: TypeSpec | None = None, + output_typespec: TypeSpec | None = None, ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") @@ -38,8 +38,8 @@ def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_types: TypeSpec | None = None, - output_types: TypeSpec | None = None, + input_typespec: TypeSpec | None = None, + output_typespec: TypeSpec | None = None, ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") From e35b024ca0d4e9cea92d4a178671e2d12dd1ffda Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:44:29 +0000 Subject: [PATCH 19/57] feat: collect refined hashing functions --- src/orcapod/hashing/hash_utils.py | 304 ++++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 src/orcapod/hashing/hash_utils.py diff --git a/src/orcapod/hashing/hash_utils.py b/src/orcapod/hashing/hash_utils.py new file mode 100644 index 0000000..df2d435 --- /dev/null +++ b/src/orcapod/hashing/hash_utils.py @@ -0,0 +1,304 @@ +from typing import Any +from .function_info_extractors import FunctionInfoExtractor +import logging +import json +from uuid import UUID +from pathlib import Path +from collections.abc import Mapping, Collection +import hashlib +import xxhash +import zlib + +logger = logging.getLogger(__name__) + + +def serialize_through_json(processed_obj) -> bytes: + """ + Create a deterministic string representation of a processed object structure. + + Args: + processed_obj: The processed object to serialize + + Returns: + A bytes object ready for hashing + """ + # TODO: add type check of processed obj + return json.dumps(processed_obj, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + + +def process_structure( + obj: Any, + visited: set[int] | None = None, + function_info_extractor: FunctionInfoExtractor | None = None, + force_hash: bool = False, +) -> Any: + """ + Recursively process a structure to prepare it for hashing. 
+ + Args: + obj: The object or structure to process + visited: Set of object ids already visited (to handle circular references) + function_info_extractor: FunctionInfoExtractor to be used for extracting necessary function representation + + Returns: + A processed version of the structure suitable for stable hashing + """ + # Initialize the visited set if this is the top-level call + if visited is None: + visited = set() + + # Check for circular references - use object's memory address + # NOTE: While id() is not stable across sessions, we only use it within a session + # to detect circular references, not as part of the final hash + obj_id = id(obj) + if obj_id in visited: + logger.debug( + f"Detected circular reference for object of type {type(obj).__name__}" + ) + return "CircularRef" # Don't include the actual id in hash output + + # For objects that could contain circular references, add to visited + if isinstance(obj, (dict, list, tuple, set)) or not isinstance( + obj, (str, int, float, bool, type(None)) + ): + visited.add(obj_id) + + # Handle None + if obj is None: + return None + + from .content_identifiable import ContentIdentifiableBase + + if isinstance(obj, ContentIdentifiableBase): + logger.debug( + f"Processing ContentHashableBase instance of type {type(obj).__name__}" + ) + # replace the object with expanded identity structure and re-process + return process_structure( + obj.identity_structure(), visited, function_info_extractor + ) + + # Handle basic types + if isinstance(obj, (str, int, float, bool)): + return obj + + # Handle bytes and bytearray + if isinstance(obj, (bytes, bytearray)): + logger.debug( + f"Converting bytes/bytearray of length {len(obj)} to hex representation" + ) + return obj.hex() + + # Handle Path objects + if isinstance(obj, Path): + logger.debug(f"Converting Path object to string: {obj}") + return str(obj) + + # Handle UUID objects + if isinstance(obj, UUID): + logger.debug(f"Converting UUID to string: {obj}") + return str(obj) + + # Handle named tuples (which are subclasses of tuple) + if hasattr(obj, "_fields") and isinstance(obj, tuple): + logger.debug(f"Processing named tuple of type {type(obj).__name__}") + # For namedtuples, convert to dict and then process + d = {field: getattr(obj, field) for field in obj._fields} # type: ignore + return process_structure(d, visited, function_info_extractor) + + # Handle mappings (dict-like objects) + if isinstance(obj, Mapping): + # Process both keys and values + processed_items = [ + ( + process_structure(k, visited, function_info_extractor), + process_structure(v, visited, function_info_extractor), + ) + for k, v in obj.items() + ] + + # Sort by the processed keys for deterministic order + processed_items.sort(key=lambda x: str(x[0])) + + # Create a new dictionary with string keys based on processed keys + # TODO: consider checking for possibly problematic values in processed_k + # and issue a warning + return { + str(processed_k): processed_v + for processed_k, processed_v in processed_items + } + + # Handle sets and frozensets + if isinstance(obj, (set, frozenset)): + logger.debug( + f"Processing set/frozenset of type {type(obj).__name__} with {len(obj)} items" + ) + # Process each item first, then sort the processed results + processed_items = [ + process_structure(item, visited, function_info_extractor) for item in obj + ] + return sorted(processed_items, key=str) + + # Handle collections (list-like objects) + if isinstance(obj, Collection): + logger.debug( + f"Processing collection of type 
{type(obj).__name__} with {len(obj)} items" + ) + return [ + process_structure(item, visited, function_info_extractor) for item in obj + ] + + # For functions, use the function_content_hash + if callable(obj) and hasattr(obj, "__code__"): + logger.debug(f"Processing function: {getattr(obj, '__name__')}") + if function_info_extractor is not None: + # Use the extractor to get a stable representation + function_info = function_info_extractor.extract_function_info(obj) + logger.debug(f"Extracted function info: {function_info} for {obj.__name__}") + + # simply return the function info as a stable representation + return function_info + else: + raise ValueError( + f"Function {obj} encountered during processing but FunctionInfoExtractor is missing" + ) + + # For other objects, attempt to create deterministic representation only if force_hash=True + class_name = obj.__class__.__name__ + module_name = obj.__class__.__module__ + if force_hash: + try: + import re + + logger.debug( + f"Processing generic object of type {module_name}.{class_name}" + ) + + # Try to get a stable dict representation if possible + if hasattr(obj, "__dict__"): + # Sort attributes to ensure stable order + attrs = sorted( + (k, v) for k, v in obj.__dict__.items() if not k.startswith("_") + ) + # Limit to first 10 attributes to avoid extremely long representations + if len(attrs) > 10: + logger.debug( + f"Object has {len(attrs)} attributes, limiting to first 10" + ) + attrs = attrs[:10] + attr_strs = [f"{k}={type(v).__name__}" for k, v in attrs] + obj_repr = f"{{{', '.join(attr_strs)}}}" + else: + # Get basic repr but remove memory addresses + logger.debug( + "Object has no __dict__, using repr() with memory address removal" + ) + obj_repr = repr(obj) + if len(obj_repr) > 1000: + logger.debug( + f"Object repr is {len(obj_repr)} chars, truncating to 1000" + ) + obj_repr = obj_repr[:1000] + "..." + # Remove memory addresses which look like '0x7f9a1c2b3d4e' + obj_repr = re.sub(r" at 0x[0-9a-f]+", " at 0xMEMADDR", obj_repr) + + return f"{module_name}.{class_name}-{obj_repr}" + except Exception as e: + # Last resort - use class name only + logger.warning(f"Failed to process object representation: {e}") + try: + return f"Object-{obj.__class__.__module__}.{obj.__class__.__name__}" + except AttributeError: + logger.error("Could not determine object class, using UnknownObject") + return "UnknownObject" + else: + raise ValueError( + f"Processing of {obj} of type {module_name}.{class_name} is not supported" + ) + + +def hash_object( + obj: Any, + function_info_extractor: FunctionInfoExtractor | None = None, +) -> bytes: + # Process the object to handle nested structures and HashableMixin instances + processed = process_structure(obj, function_info_extractor=function_info_extractor) + + # Serialize the processed structure + json_str = json.dumps(processed, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + logger.debug( + f"Successfully serialized {type(obj).__name__} using custom serializer" + ) + + # Create the hash + return hashlib.sha256(json_str).digest() + + +def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> bytes: + """ + Calculate the hash of a file using the specified algorithm. 
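+
+    Example (illustrative sketch only; the file path below is made up):
+
+        digest = hash_file("data/input.csv", algorithm="sha256")  # 32 raw bytes
+        hex_digest = digest.hex()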
+ + Parameters: + file_path (str): Path to the file to hash + algorithm (str): Hash algorithm to use - options include: + 'md5', 'sha1', 'sha256', 'sha512', 'xxh64', 'crc32', 'hash_path' + buffer_size (int): Size of chunks to read from the file at a time + + Returns: + bytes: Raw digest bytes of the hash + """ + # Verify the file exists + if not Path(file_path).is_file(): + raise FileNotFoundError(f"The file {file_path} does not exist") + + # Handle special case for 'hash_path' algorithm + if algorithm == "hash_path": + # Hash the path of the file instead of its content + # This is useful for cases where the file content is well known or + # not relevant + hasher = hashlib.sha256() + hasher.update(file_path.encode("utf-8")) + return hasher.digest() + + # Handle non-cryptographic hash functions + if algorithm == "xxh64": + hasher = xxhash.xxh64() + with open(file_path, "rb") as file: + while True: + data = file.read(buffer_size) + if not data: + break + hasher.update(data) + return hasher.digest() + + if algorithm == "crc32": + crc = 0 + with open(file_path, "rb") as file: + while True: + data = file.read(buffer_size) + if not data: + break + crc = zlib.crc32(data, crc) + return (crc & 0xFFFFFFFF).to_bytes(4, byteorder="big") + + # Handle cryptographic hash functions from hashlib + try: + hasher = hashlib.new(algorithm) + except ValueError: + valid_algorithms = ", ".join(sorted(hashlib.algorithms_available)) + raise ValueError( + f"Invalid algorithm: {algorithm}. Available algorithms: {valid_algorithms}, xxh64, crc32" + ) + + with open(file_path, "rb") as file: + while True: + data = file.read(buffer_size) + if not data: + break + hasher.update(data) + + return hasher.digest() From 02412d08b4834a2d2c183b538a985aa277a7318a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:44:51 +0000 Subject: [PATCH 20/57] feat: collect semantic type hashers into a module --- src/orcapod/hashing/semantic_type_hashers.py | 64 ++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 src/orcapod/hashing/semantic_type_hashers.py diff --git a/src/orcapod/hashing/semantic_type_hashers.py b/src/orcapod/hashing/semantic_type_hashers.py new file mode 100644 index 0000000..36dfd53 --- /dev/null +++ b/src/orcapod/hashing/semantic_type_hashers.py @@ -0,0 +1,64 @@ +from .types import SemanticTypeHasher, FileHasher +import os +import hashlib +import pyarrow as pa + + +class PathHasher(SemanticTypeHasher): + """Hasher for Path semantic type columns - hashes file contents.""" + + def __init__(self, file_hasher: FileHasher, handle_missing: str = "error"): + """ + Initialize PathHasher. 
+ + Args: + chunk_size: Size of chunks to read files in bytes + handle_missing: How to handle missing files ('error', 'skip', 'null_hash') + """ + self.file_hasher = file_hasher + self.handle_missing = handle_missing + + def _hash_file_content(self, file_path: str) -> str: + """Hash the content of a single file and return hex string.""" + import os + + try: + if not os.path.exists(file_path): + if self.handle_missing == "error": + raise FileNotFoundError(f"File not found: {file_path}") + elif self.handle_missing == "skip": + return hashlib.sha256(b"").hexdigest() + elif self.handle_missing == "null_hash": + return hashlib.sha256(b"").hexdigest() + + return self.file_hasher.hash_file(file_path).hex() + + except (IOError, OSError, PermissionError) as e: + if self.handle_missing == "error": + raise IOError(f"Cannot read file {file_path}: {e}") + else: # skip or null_hash + error_msg = f"" + return hashlib.sha256(error_msg.encode("utf-8")).hexdigest() + + def hash_column(self, column: pa.Array) -> pa.Array: + """ + Replace path column with file content hashes. + Returns a new array where each path is replaced with its file content hash. + """ + + # Convert to python list for processing + paths = column.to_pylist() + + # Hash each file's content individually + content_hashes = [] + for path in paths: + if path is not None: + # Normalize path for consistency + normalized_path = os.path.normpath(str(path)) + file_content_hash = self._hash_file_content(normalized_path) + content_hashes.append(file_content_hash) + else: + content_hashes.append(None) # Preserve nulls + + # Return new array with content hashes instead of paths + return pa.array(content_hashes) From 1e9067968173812f4eaf6e0e71be83e282ea375f Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:45:19 +0000 Subject: [PATCH 21/57] refactor: make file hasher return bytes --- src/orcapod/hashing/file_hashers.py | 75 ++++++++++++++++++++--------- src/orcapod/hashing/types.py | 17 ++++--- 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/src/orcapod/hashing/file_hashers.py b/src/orcapod/hashing/file_hashers.py index 77833ee..58076ac 100644 --- a/src/orcapod/hashing/file_hashers.py +++ b/src/orcapod/hashing/file_hashers.py @@ -1,4 +1,5 @@ -from orcapod.hashing.core import hash_file, hash_pathset, hash_packet +from orcapod.hashing import legacy_core +from orcapod.hashing.hash_utils import hash_file from orcapod.hashing.types import ( FileHasher, PathSetHasher, @@ -8,8 +9,6 @@ from orcapod.types import Packet, PathLike, PathSet -# Completely unnecessary to inherit from FileHasher, but this -# allows for type checking based on isinstance class BasicFileHasher: """Basic implementation for file hashing.""" @@ -21,7 +20,7 @@ def __init__( self.algorithm = algorithm self.buffer_size = buffer_size - def hash_file(self, file_path: PathLike) -> str: + def hash_file(self, file_path: PathLike) -> bytes: return hash_file( file_path, algorithm=self.algorithm, buffer_size=self.buffer_size ) @@ -38,18 +37,38 @@ def __init__( self.file_hasher = file_hasher self.string_cacher = string_cacher - def hash_file(self, file_path: PathLike) -> str: + def hash_file(self, file_path: PathLike) -> bytes: cache_key = f"file:{file_path}" cached_value = self.string_cacher.get_cached(cache_key) if cached_value is not None: - return cached_value + return bytes.fromhex(cached_value) value = self.file_hasher.hash_file(file_path) - self.string_cacher.set_cached(cache_key, value) + self.string_cacher.set_cached(cache_key, value.hex()) return 
value -class DefaultPathsetHasher: +# ----------------Legacy implementations for backward compatibility----------------- + + +class LegacyFileHasher: + def __init__( + self, + algorithm: str = "sha256", + buffer_size: int = 65536, + ): + self.algorithm = algorithm + self.buffer_size = buffer_size + + def hash_file(self, file_path: PathLike) -> bytes: + return bytes.fromhex( + legacy_core.hash_file( + file_path, algorithm=self.algorithm, buffer_size=self.buffer_size + ), + ) + + +class LegacyPathsetHasher: """Default pathset hasher that composes file hashing.""" def __init__( @@ -60,16 +79,21 @@ def __init__( self.file_hasher = file_hasher self.char_count = char_count - def hash_pathset(self, pathset: PathSet) -> str: + def _hash_file_to_hex(self, file_path: PathLike) -> str: + return self.file_hasher.hash_file(file_path).hex() + + def hash_pathset(self, pathset: PathSet) -> bytes: """Hash a pathset using the injected file hasher.""" - return hash_pathset( - pathset, - char_count=self.char_count, - file_hasher=self.file_hasher.hash_file, # Inject the method + return bytes.fromhex( + legacy_core.hash_pathset( + pathset, + char_count=self.char_count, + file_hasher=self._hash_file_to_hex, # Inject the method + ) ) -class DefaultPacketHasher: +class LegacyPacketHasher: """Default packet hasher that composes pathset hashing.""" def __init__( @@ -82,19 +106,22 @@ def __init__( self.char_count = char_count self.prefix = prefix + def _hash_pathset_to_hex(self, pathset: PathSet): + return self.pathset_hasher.hash_pathset(pathset).hex() + def hash_packet(self, packet: Packet) -> str: """Hash a packet using the injected pathset hasher.""" - hash_str = hash_packet( + hash_str = legacy_core.hash_packet( packet, char_count=self.char_count, prefix_algorithm=False, # Will apply prefix on our own - pathset_hasher=self.pathset_hasher.hash_pathset, # Inject the method + pathset_hasher=self._hash_pathset_to_hex, # Inject the method ) return f"{self.prefix}-{hash_str}" if self.prefix else hash_str # Convenience composite implementation -class DefaultCompositeFileHasher: +class LegacyCompositeFileHasher: """Composite hasher that implements all interfaces.""" def __init__( @@ -104,15 +131,15 @@ def __init__( packet_prefix: str = "", ): self.file_hasher = file_hasher - self.pathset_hasher = DefaultPathsetHasher(self.file_hasher, char_count) - self.packet_hasher = DefaultPacketHasher( + self.pathset_hasher = LegacyPathsetHasher(self.file_hasher, char_count) + self.packet_hasher = LegacyPacketHasher( self.pathset_hasher, char_count, packet_prefix ) - def hash_file(self, file_path: PathLike) -> str: + def hash_file(self, file_path: PathLike) -> bytes: return self.file_hasher.hash_file(file_path) - def hash_pathset(self, pathset: PathSet) -> str: + def hash_pathset(self, pathset: PathSet) -> bytes: return self.pathset_hasher.hash_pathset(pathset) def hash_packet(self, packet: Packet) -> str: @@ -120,7 +147,7 @@ def hash_packet(self, packet: Packet) -> str: # Factory for easy construction -class PathLikeHasherFactory: +class LegacyPathLikeHasherFactory: """Factory for creating various hasher combinations.""" @staticmethod @@ -132,7 +159,7 @@ def create_basic_composite( """Create a basic composite hasher.""" file_hasher = BasicFileHasher(algorithm, buffer_size) # use algorithm as the prefix for the packet hasher - return DefaultCompositeFileHasher( + return LegacyCompositeFileHasher( file_hasher, char_count, packet_prefix=algorithm ) @@ -146,7 +173,7 @@ def create_cached_composite( """Create a composite hasher with file 
caching.""" basic_file_hasher = BasicFileHasher(algorithm, buffer_size) cached_file_hasher = CachedFileHasher(basic_file_hasher, string_cacher) - return DefaultCompositeFileHasher( + return LegacyCompositeFileHasher( cached_file_hasher, char_count, packet_prefix=algorithm ) diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index abae409..310b5a2 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -29,6 +29,7 @@ def identity_structure(self) -> Any: class ObjectHasher(ABC): """Abstract class for general object hashing.""" + # TODO: consider more explicitly stating types of objects accepted @abstractmethod def hash(self, obj: Any) -> bytes: """ @@ -81,7 +82,7 @@ def hash_to_uuid( class FileHasher(Protocol): """Protocol for file-related hashing.""" - def hash_file(self, file_path: PathLike) -> str: ... + def hash_file(self, file_path: PathLike) -> bytes: ... # Higher-level operations that compose file hashing @@ -89,12 +90,7 @@ def hash_file(self, file_path: PathLike) -> str: ... class PathSetHasher(Protocol): """Protocol for hashing pathsets (files, directories, collections).""" - def hash_pathset(self, pathset: PathSet) -> str: ... - - -@runtime_checkable -class SemanticHasher(Protocol): - pass + def hash_pathset(self, pathset: PathSet) -> bytes: ... @runtime_checkable @@ -142,3 +138,10 @@ def extract_function_info( ) -> dict[str, Any]: ... +class SemanticTypeHasher(Protocol): + """Abstract base class for semantic type-specific hashers.""" + + @abstractmethod + def hash_column(self, column: pa.Array) -> list[bytes]: + """Hash a column with this semantic type and return the hash bytes.""" + pass From 78fdead1c5c84cd08a6f03789eb61276e014d1a9 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:46:03 +0000 Subject: [PATCH 22/57] feat: add new defaut object hasher --- src/orcapod/hashing/defaults.py | 34 ++++++++++++++++++--------- src/orcapod/hashing/object_hashers.py | 30 ++++++++++++++++++++--- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 85e1405..1f3aca2 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -1,26 +1,27 @@ # A collection of utility function that provides a "default" implementation of hashers. # This is often used as the fallback hasher in the library code. 
-from orcapod.hashing.types import CompositeFileHasher, ArrowHasher -from orcapod.hashing.file_hashers import PathLikeHasherFactory +from orcapod.hashing.types import CompositeFileHasher, ArrowHasher, FileHasher +from orcapod.hashing.file_hashers import BasicFileHasher, LegacyPathLikeHasherFactory from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.hashing.object_hashers import ObjectHasher -from orcapod.hashing.object_hashers import LegacyObjectHasher +from orcapod.hashing.object_hashers import DefaultObjectHasher, LegacyObjectHasher from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory -from orcapod.hashing.semantic_arrow_hasher import SemanticArrowHasher, PathHasher +from orcapod.hashing.arrow_hashers import SemanticArrowHasher +from orcapod.hashing.semantic_type_hashers import PathHasher def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: if with_cache: # use unlimited caching string_cacher = InMemoryCacher(max_size=None) - return PathLikeHasherFactory.create_cached_composite(string_cacher) - return PathLikeHasherFactory.create_basic_composite() + return LegacyPathLikeHasherFactory.create_cached_composite(string_cacher) + return LegacyPathLikeHasherFactory.create_basic_composite() def get_default_composite_file_hasher_with_cacher(cacher=None) -> CompositeFileHasher: if cacher is None: cacher = InMemoryCacher(max_size=None) - return PathLikeHasherFactory.create_cached_composite(cacher) + return LegacyPathLikeHasherFactory.create_cached_composite(cacher) def get_default_object_hasher() -> ObjectHasher: @@ -29,15 +30,26 @@ def get_default_object_hasher() -> ObjectHasher: strategy="signature" ) ) - return LegacyObjectHasher( - char_count=32, function_info_extractor=function_info_extractor + return DefaultObjectHasher(function_info_extractor=function_info_extractor) + + +def get_legacy_object_hasher() -> ObjectHasher: + function_info_extractor = ( + FunctionInfoExtractorFactory.create_function_info_extractor( + strategy="signature" + ) ) + return LegacyObjectHasher(function_info_extractor=function_info_extractor) def get_default_arrow_hasher( - chunk_size: int = 8192, handle_missing: str = "error" + chunk_size: int = 8192, + handle_missing: str = "error", + file_hasher: FileHasher | None = None, ) -> ArrowHasher: + if file_hasher is None: + file_hasher = BasicFileHasher() hasher = SemanticArrowHasher(chunk_size=chunk_size, handle_missing=handle_missing) # register semantic hasher for Path - hasher.register_semantic_hasher("Path", PathHasher()) + hasher.register_semantic_hasher("Path", PathHasher(file_hasher=file_hasher)) return hasher diff --git a/src/orcapod/hashing/object_hashers.py b/src/orcapod/hashing/object_hashers.py index a3f4b39..7e35ccb 100644 --- a/src/orcapod/hashing/object_hashers.py +++ b/src/orcapod/hashing/object_hashers.py @@ -1,5 +1,31 @@ +from polars import Object from .types import FunctionInfoExtractor, ObjectHasher -from .core import legacy_hash +from .legacy_core import legacy_hash +from .hash_utils import hash_object + + +class DefaultObjectHasher(ObjectHasher): + """ + Default object hasher used throughout the codebase. + """ + + def __init__( + self, + function_info_extractor: FunctionInfoExtractor | None = None, + ): + self.function_info_extractor = function_info_extractor + + def hash(self, obj: object) -> bytes: + """ + Hash an object to a byte representation. + + Args: + obj (object): The object to hash. + + Returns: + bytes: The byte representation of the hash. 
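+
+        Example (illustrative sketch only):
+
+            hasher = DefaultObjectHasher()
+            digest = hasher.hash({"a": 1, "b": [2, 3]})  # sha256 digest as bytes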
+ """ + return hash_object(obj, function_info_extractor=self.function_info_extractor) class LegacyObjectHasher(ObjectHasher): @@ -13,7 +39,6 @@ class LegacyObjectHasher(ObjectHasher): def __init__( self, - char_count: int | None = 32, function_info_extractor: FunctionInfoExtractor | None = None, ): """ @@ -22,7 +47,6 @@ def __init__( Args: function_info_extractor (FunctionInfoExtractor | None): Optional extractor for function information. This must be provided if an object containing function information is to be hashed. """ - self.char_count = char_count self.function_info_extractor = function_info_extractor def hash(self, obj: object) -> bytes: From 3dcaa0b377f67ab96166b532ea54fc7bbf06f9fe Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 07:46:13 +0000 Subject: [PATCH 23/57] test: update ref --- tests/test_hashing/test_process_structure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hashing/test_process_structure.py b/tests/test_hashing/test_process_structure.py index 933e2dc..2967ed4 100644 --- a/tests/test_hashing/test_process_structure.py +++ b/tests/test_hashing/test_process_structure.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Any -from orcapod.hashing.core import HashableMixin, hash_to_hex, process_structure +from orcapod.hashing.legacy_core import HashableMixin, hash_to_hex, process_structure # Define a simple HashableMixin class for testing From 89ddd76bf2cf0cf130a9f633f1523c495da8cc79 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 27 Jun 2025 16:31:08 +0000 Subject: [PATCH 24/57] fix: handle type vars in process_structure --- src/orcapod/hashing/hash_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/orcapod/hashing/hash_utils.py b/src/orcapod/hashing/hash_utils.py index df2d435..7fee36b 100644 --- a/src/orcapod/hashing/hash_utils.py +++ b/src/orcapod/hashing/hash_utils.py @@ -165,6 +165,11 @@ def process_structure( f"Function {obj} encountered during processing but FunctionInfoExtractor is missing" ) + # handle data types + if isinstance(obj, type): + logger.debug(f"Processing class/type: {obj.__name__}") + return f"type:{obj.__class__.__module__}.{obj.__class__.__name__}" + # For other objects, attempt to create deterministic representation only if force_hash=True class_name = obj.__class__.__name__ module_name = obj.__class__.__module__ @@ -204,12 +209,12 @@ def process_structure( # Remove memory addresses which look like '0x7f9a1c2b3d4e' obj_repr = re.sub(r" at 0x[0-9a-f]+", " at 0xMEMADDR", obj_repr) - return f"{module_name}.{class_name}-{obj_repr}" + return f"{module_name}.{class_name}:{obj_repr}" except Exception as e: # Last resort - use class name only logger.warning(f"Failed to process object representation: {e}") try: - return f"Object-{obj.__class__.__module__}.{obj.__class__.__name__}" + return f"object:{obj.__class__.__module__}.{obj.__class__.__name__}" except AttributeError: logger.error("Could not determine object class, using UnknownObject") return "UnknownObject" From 905f91524b34e7d1cacfbaf9dbf868e20a1ec685 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Mon, 30 Jun 2025 19:39:50 +0000 Subject: [PATCH 25/57] wip: use new schema system --- src/orcapod/core/base.py | 7 +- src/orcapod/core/operators.py | 18 +- src/orcapod/core/pod.py | 6 +- src/orcapod/core/sources.py | 2 +- src/orcapod/core/streams.py | 10 +- src/orcapod/hashing/__init__.py | 4 +- src/orcapod/hashing/arrow_hashers.py | 77 ++- src/orcapod/hashing/arrow_utils.py | 403 +++++++++++++++ src/orcapod/hashing/defaults.py | 47 +- src/orcapod/hashing/file_hashers.py | 18 +- src/orcapod/hashing/semantic_type_hashers.py | 35 +- src/orcapod/hashing/string_cachers.py | 26 +- src/orcapod/hashing/types.py | 16 +- src/orcapod/hashing/versioned_hashers.py | 71 +++ src/orcapod/pipeline/wrappers.py | 121 ++--- src/orcapod/types/__init__.py | 18 +- src/orcapod/types/core.py | 29 +- src/orcapod/types/packet_converter.py | 177 +++++++ src/orcapod/types/packets.py | 241 +++++++++ src/orcapod/types/registry.py | 437 ---------------- src/orcapod/types/schemas.py | 267 ++++++++++ ...{handlers.py => semantic_type_handlers.py} | 12 +- src/orcapod/types/semantic_type_registry.py | 468 ++++++++++++++++++ .../types/{typespec.py => typespec_utils.py} | 59 ++- src/orcapod/utils/stream_utils.py | 50 -- tests/test_hashing/test_composite_hasher.py | 156 ------ tests/test_store/test_transfer_data_store.py | 1 - .../test_extract_function_data_types.py | 2 +- 28 files changed, 1945 insertions(+), 833 deletions(-) create mode 100644 src/orcapod/hashing/arrow_utils.py create mode 100644 src/orcapod/hashing/versioned_hashers.py create mode 100644 src/orcapod/types/packet_converter.py create mode 100644 src/orcapod/types/packets.py delete mode 100644 src/orcapod/types/registry.py create mode 100644 src/orcapod/types/schemas.py rename src/orcapod/types/{handlers.py => semantic_type_handlers.py} (92%) create mode 100644 src/orcapod/types/semantic_type_registry.py rename src/orcapod/types/{typespec.py => typespec_utils.py} (83%) delete mode 100644 tests/test_hashing/test_composite_hasher.py diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index fc18b48..9a30873 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -10,8 +10,7 @@ from orcapod.hashing import ContentIdentifiableBase from orcapod.types import Packet, Tag, TypeSpec -from orcapod.utils.stream_utils import get_typespec - +from orcapod.types.typespec import get_typespec_from_dict import logging @@ -151,7 +150,7 @@ def types( return None, None tag, packet = next(iter(self(*streams))) - return get_typespec(tag), get_typespec(packet) + return get_typespec_from_dict(tag), get_typespec_from_dict(packet) def claims_unique_tags( self, *streams: "SyncStream", trigger_run: bool = False @@ -391,7 +390,7 @@ def types(self, *, trigger_run=False) -> tuple[TypeSpec | None, TypeSpec | None] # otherwise, use the keys from the first packet in the stream # note that this may be computationally expensive tag, packet = next(iter(self)) - return tag_types or get_typespec(tag), packet_types or get_typespec(packet) + return tag_types or get_typespec_from_dict(tag), packet_types or get_typespec_from_dict(packet) def claims_unique_tags(self, *, trigger_run=False) -> bool | None: """ diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index 53ecacc..c68f34f 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -4,6 +4,7 @@ from typing import Any from orcapod.types import Packet, Tag, TypeSpec +from orcapod.types.typespec_utils import union_typespecs, intersection_typespecs from 
orcapod.hashing import function_content_hash, hash_function from orcapod.core.base import Kernel, SyncStream, Operator from orcapod.core.streams import SyncStreamFromGenerator @@ -11,11 +12,8 @@ batch_packet, batch_tags, check_packet_compatibility, - intersection_typespecs, join_tags, semijoin_tags, - union_typespecs, - intersection_typespecs, fill_missing ) @@ -268,7 +266,7 @@ def generator() -> Iterator[tuple[Tag, Packet]]: raise ValueError( f"Packets are not compatible: {left_packet} and {right_packet}" ) - yield joined_tag, {**left_packet, **right_packet} + yield joined_tag, Packet({**left_packet, **right_packet}) return SyncStreamFromGenerator(generator) @@ -307,7 +305,7 @@ def generator(): ) # match is found - remove the packet from the inner stream inner_stream.pop(idx) - yield joined_tag, {**outer_packet, **inner_packet} + yield joined_tag, Packet({**outer_packet, **inner_packet}) # if enough matches found, move onto the next outer stream packet break @@ -402,11 +400,11 @@ def forward(self, *streams: SyncStream) -> SyncStream: def generator(): for tag, packet in stream: if self.drop_unmapped: - packet = { + packet = Packet({ v: packet[k] for k, v in self.key_map.items() if k in packet - } + }) else: - packet = {self.key_map.get(k, k): v for k, v in packet.items()} + packet = Packet({self.key_map.get(k, k): v for k, v in packet.items()}) yield tag, packet return SyncStreamFromGenerator(generator) @@ -861,9 +859,9 @@ def generator() -> Iterator[tuple[Tag, Packet]]: if k not in new_tag: new_tag[k] = [t.get(k, None) for t, _ in packets] # combine all packets into a single packet - combined_packet: Packet = { + combined_packet: Packet = Packet({ k: [p.get(k, None) for _, p in packets] for k in packet_keys - } + }) yield new_tag, combined_packet return SyncStreamFromGenerator(generator) diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index 77d1610..eb880b4 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -8,8 +8,8 @@ ) from orcapod.types import Packet, Tag, TypeSpec, default_registry -from orcapod.types.typespec import extract_function_typespecs -from orcapod.types.registry import PacketConverter +from orcapod.types.typespec_utils import extract_function_typespecs +from orcapod.types.semantic_type_registry import PacketConverter from orcapod.hashing import ( FunctionInfoExtractor, @@ -258,7 +258,7 @@ def call(self, tag, packet) -> tuple[Tag, Packet | None]: f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" ) - output_packet: Packet = {k: v for k, v in zip(self.output_keys, output_values)} + output_packet: Packet = Packet({k: v for k, v in zip(self.output_keys, output_values)}) return tag, output_packet def identity_structure(self, *streams) -> Any: diff --git a/src/orcapod/core/sources.py b/src/orcapod/core/sources.py index 33df20d..21adae9 100644 --- a/src/orcapod/core/sources.py +++ b/src/orcapod/core/sources.py @@ -78,7 +78,7 @@ def forward(self, *streams: SyncStream) -> SyncStream: def generator() -> Iterator[tuple[Tag, Packet]]: for file in Path(self.file_path).glob(self.pattern): - yield self.tag_function(file), {self.name: str(file)} + yield self.tag_function(file), Packet({self.name: str(file)}) return SyncStreamFromGenerator(generator) diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index c70b009..33f6b78 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -1,7 +1,7 @@ from collections.abc import 
Callable, Collection, Iterator from orcapod.core.base import SyncStream -from orcapod.types import Packet, Tag, TypeSpec +from orcapod.types import Packet, PacketLike, Tag, TypeSpec from copy import copy @@ -9,8 +9,8 @@ class SyncStreamFromLists(SyncStream): def __init__( self, tags: Collection[Tag] | None = None, - packets: Collection[Packet] | None = None, - paired: Collection[tuple[Tag, Packet]] | None = None, + packets: Collection[PacketLike] | None = None, + paired: Collection[tuple[Tag, PacketLike]] | None = None, tag_keys: list[str] | None = None, packet_keys: list[str] | None = None, tag_typespec: TypeSpec | None = None, @@ -33,9 +33,9 @@ def __init__( raise ValueError( "tags and packets must have the same length if both are provided" ) - self.paired = list(zip(tags, packets)) + self.paired = list((t, Packet(v)) for t, v in zip(tags, packets)) elif paired is not None: - self.paired = list(paired) + self.paired = list((t, Packet(v)) for t, v in paired) else: raise ValueError( "Either tags and packets or paired must be provided to SyncStreamFromLists" diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index 2354696..2bdff2b 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -16,7 +16,7 @@ get_default_arrow_hasher, ) from .types import ( - FileHasher, + FileContentHasher, PacketHasher, ArrowHasher, ObjectHasher, @@ -27,7 +27,7 @@ from .content_identifiable import ContentIdentifiableBase __all__ = [ - "FileHasher", + "FileContentHasher", "PacketHasher", "ArrowHasher", "StringCacher", diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index 728b904..c50ebfc 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -4,7 +4,33 @@ import pyarrow.ipc as ipc from io import BytesIO import polars as pl -from .types import SemanticTypeHasher +import json +from orcapod.hashing.types import SemanticTypeHasher, StringCacher + + +def serialize_pyarrow_table(table: pa.Table) -> str: + """ + Serialize a PyArrow table to a stable JSON string by converting to dictionary of lists. + + Args: + table: PyArrow table to serialize + + Returns: + JSON string representation with sorted keys and no whitespace + """ + # Convert table to dictionary of lists using to_pylist() + data_dict = {} + + for column_name in table.column_names: + # Convert Arrow column to Python list, which visits all elements + data_dict[column_name] = table.column(column_name).to_pylist() + + # Serialize to JSON with sorted keys and no whitespace + return json.dumps( + data_dict, + separators=(",", ":"), + sort_keys=True, + ) class SemanticArrowHasher: @@ -18,7 +44,14 @@ class SemanticArrowHasher: 4. Computes final hash of the processed packet """ - def __init__(self, chunk_size: int = 8192, handle_missing: str = "error"): + def __init__( + self, + hasher_id: str, + hash_algorithm: str = "sha256", + chunk_size: int = 8192, + handle_missing: str = "error", + semantic_type_hashers: dict[str, SemanticTypeHasher] | None = None, + ): """ Initialize SemanticArrowHasher. 
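+
+        Example (illustrative sketch only; the hasher_id and table below are made up):
+
+            hasher = SemanticArrowHasher("arrow-ipc-v1")
+            table = pa.table({"x": [1, 2, 3]})
+            hasher.hash_table(table)  # -> "arrow-ipc-v1:<hex digest>"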
@@ -26,9 +59,28 @@ def __init__(self, chunk_size: int = 8192, handle_missing: str = "error"): chunk_size: Size of chunks to read files in bytes handle_missing: How to handle missing files ('error', 'skip', 'null_hash') """ + self._hasher_id = hasher_id self.chunk_size = chunk_size self.handle_missing = handle_missing - self.semantic_type_hashers: dict[str, SemanticTypeHasher] = {} + self.semantic_type_hashers: dict[str, SemanticTypeHasher] = ( + semantic_type_hashers or {} + ) + self.hash_algorithm = hash_algorithm + + def set_cacher(self, semantic_type: str, cacher: StringCacher) -> None: + """ + Add a string cacher for caching hash values. + + This is a no-op for SemanticArrowHasher since it hashes column contents directly. + """ + # SemanticArrowHasher does not use string caching, so this is a no-op + if semantic_type in self.semantic_type_hashers: + self.semantic_type_hashers[semantic_type].set_cacher(cacher) + else: + raise KeyError(f"No hasher registered for semantic type '{semantic_type}'") + + def get_hasher_id(self) -> str: + return self._hasher_id def register_semantic_hasher(self, semantic_type: str, hasher: SemanticTypeHasher): """Register a custom hasher for a semantic type.""" @@ -117,6 +169,7 @@ def _sort_table_columns(self, table: pa.Table) -> pa.Table: return pa.table(sorted_columns, schema=sorted_schema) def _serialize_table_ipc(self, table: pa.Table) -> bytes: + # TODO: fix and use logical table hashing instead """Serialize table using Arrow IPC format for stable binary representation.""" buffer = BytesIO() @@ -126,13 +179,12 @@ def _serialize_table_ipc(self, table: pa.Table) -> bytes: return buffer.getvalue() - def hash_table(self, table: pa.Table, algorithm: str = "sha256") -> str: + def hash_table(self, table: pa.Table, add_prefix: bool = True) -> str: """ Compute stable hash of Arrow table. Args: table: Arrow table to hash - algorithm: Hash algorithm to use ('sha256', 'md5', etc.) Returns: Hex string of the computed hash @@ -152,14 +204,16 @@ def hash_table(self, table: pa.Table, algorithm: str = "sha256") -> str: serialized_bytes = self._serialize_table_ipc(sorted_table) # Step 4: Compute final hash - hasher = hashlib.new(algorithm) + hasher = hashlib.new(self.hash_algorithm) hasher.update(serialized_bytes) - return hasher.hexdigest() + hash_str = hasher.hexdigest() + if add_prefix: + hash_str = f"{self.get_hasher_id()}:{hash_str}" + + return hash_str - def hash_table_with_metadata( - self, table: pa.Table, algorithm: str = "sha256" - ) -> dict[str, Any]: + def hash_table_with_metadata(self, table: pa.Table) -> dict[str, Any]: """ Compute hash with additional metadata about the process. @@ -180,11 +234,10 @@ def hash_table_with_metadata( processed_columns.append(column_info) # Compute hash - table_hash = self.hash_table(table, algorithm) + table_hash = self.hash_table(table) return { "hash": table_hash, - "algorithm": algorithm, "num_rows": len(table), "num_columns": len(table.schema), "processed_columns": processed_columns, diff --git a/src/orcapod/hashing/arrow_utils.py b/src/orcapod/hashing/arrow_utils.py new file mode 100644 index 0000000..168c53f --- /dev/null +++ b/src/orcapod/hashing/arrow_utils.py @@ -0,0 +1,403 @@ +import pyarrow as pa +import json +import hashlib +from typing import Dict, List, Any +from decimal import Decimal +import base64 + + +def serialize_pyarrow_table_schema(table: pa.Table) -> str: + """ + Serialize PyArrow table schema to JSON with Python type names and filtered metadata. 
+ + Args: + table: PyArrow table + + Returns: + JSON string representation of schema + """ + schema_info = [] + + for field in table.schema: + field_info = { + "name": field.name, + "type": _arrow_type_to_python_type(field.type), + "metadata": _extract_semantic_metadata(field.metadata), + } + schema_info.append(field_info) + + return json.dumps(schema_info, separators=(",", ":"), sort_keys=True) + + +def serialize_pyarrow_table(table: pa.Table) -> str: + """ + Serialize a PyArrow table to a stable JSON string with both schema and data. + + Args: + table: PyArrow table to serialize + + Returns: + JSON string representation with schema and data sections + """ + # Convert table to dictionary of lists using to_pylist() + data_dict = {} + + for column_name in table.column_names: + column = table.column(column_name) + # Convert Arrow column to Python list, which visits all elements + column_values = column.to_pylist() + + # Handle special types that need encoding for JSON + data_dict[column_name] = [ + _serialize_value_for_json(val) for val in column_values + ] + + # Serialize schema + schema_info = [] + for field in table.schema: + field_info = { + "name": field.name, + "type": _arrow_type_to_python_type(field.type), + "metadata": _extract_semantic_metadata(field.metadata), + } + schema_info.append(field_info) + + # Combine schema and data + serialized_table = {"schema": schema_info, "data": data_dict} + + # Serialize to JSON with sorted keys and no whitespace + return json.dumps( + serialized_table, + separators=(",", ":"), + sort_keys=True, + default=_json_serializer, + ) + + +def get_pyarrow_table_hash(table: pa.Table) -> str: + """ + Get a stable SHA-256 hash of the table content. + + Args: + table: PyArrow table + + Returns: + SHA-256 hash of the serialized table + """ + serialized = serialize_pyarrow_table(table) + return hashlib.sha256(serialized.encode("utf-8")).hexdigest() + + +def deserialize_to_pyarrow_table(serialized_str: str) -> pa.Table: + """ + Deserialize JSON string back to a PyArrow table. + + Args: + serialized_str: JSON string from serialize_pyarrow_table + + Returns: + Reconstructed PyArrow table + """ + parsed_data = json.loads(serialized_str) + + # Handle both old format (dict of lists) and new format (schema + data) + if "data" in parsed_data and "schema" in parsed_data: + # New format with schema and data + data_dict = parsed_data["data"] + schema_info = parsed_data["schema"] + else: + # Old format - just data dict + data_dict = parsed_data + schema_info = None + + if not data_dict: + return pa.table([]) + + # Deserialize each column + arrays = [] + names = [] + + for column_name in sorted(data_dict.keys()): # Sort for consistency + column_values = [_deserialize_value(val) for val in data_dict[column_name]] + arrays.append(pa.array(column_values)) + names.append(column_name) + + return pa.table(arrays, names=names) + + +def _arrow_type_to_python_type(arrow_type: pa.DataType) -> str: + """ + Convert PyArrow data type to standard Python type name. 
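+
+    Example (illustrative):
+
+        pa.int64()            -> "int"
+        pa.list_(pa.string()) -> "list[str]"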
+ + Args: + arrow_type: PyArrow data type + + Returns: + Python type name as string + """ + if pa.types.is_boolean(arrow_type): + return "bool" + elif pa.types.is_integer(arrow_type): + return "int" + elif pa.types.is_floating(arrow_type): + return "float" + elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + return "str" + elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): + return "bytes" + elif pa.types.is_date(arrow_type): + return "date" + elif pa.types.is_timestamp(arrow_type): + return "datetime" + elif pa.types.is_time(arrow_type): + return "time" + elif pa.types.is_decimal(arrow_type): + return "decimal" + elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): + element_type = _arrow_type_to_python_type(arrow_type.value_type) + return f"list[{element_type}]" + elif pa.types.is_struct(arrow_type): + return "dict" + elif pa.types.is_dictionary(arrow_type): + value_type = _arrow_type_to_python_type(arrow_type.value_type) + return value_type # Dictionary encoding is transparent + elif pa.types.is_null(arrow_type): + return "NoneType" + else: + # Fallback for other types + return str(arrow_type).lower() + + +def _extract_semantic_metadata(field_metadata) -> Dict[str, str]: + """ + Extract only 'semantic_type' metadata from field metadata. + + Args: + field_metadata: PyArrow field metadata (can be None) + + Returns: + Dictionary containing only semantic_type if present, empty dict otherwise + """ + if field_metadata is None: + return {} + + metadata_dict = dict(field_metadata) + + # Only keep semantic_type if it exists + if "semantic_type" in metadata_dict: + return { + "semantic_type": metadata_dict["semantic_type"].decode("utf-8") + if isinstance(metadata_dict["semantic_type"], bytes) + else metadata_dict["semantic_type"] + } + else: + return {} + + +def _serialize_value_for_json(value: Any) -> Any: + """ + Prepare a Python value for JSON serialization. + + Args: + value: Python value from to_pylist() + + Returns: + JSON-serializable value + """ + if value is None: + return None + elif isinstance(value, bytes): + return { + "__type__": "bytes", + "__value__": base64.b64encode(value).decode("ascii"), + } + elif isinstance(value, Decimal): + return {"__type__": "decimal", "__value__": str(value)} + elif hasattr(value, "date") and hasattr(value, "time"): # datetime objects + return {"__type__": "datetime", "__value__": value.isoformat()} + elif hasattr(value, "isoformat") and not hasattr( + value, "time" + ): # date objects (no time component) + return {"__type__": "date", "__value__": value.isoformat()} + elif isinstance(value, (list, tuple)): + return [_serialize_value_for_json(item) for item in value] + elif isinstance(value, dict): + return {k: _serialize_value_for_json(v) for k, v in sorted(value.items())} + else: + return value + + +def _deserialize_value(value: Any) -> Any: + """ + Deserialize a value from the JSON representation. 
+ + Args: + value: Value from JSON + + Returns: + Python value suitable for PyArrow + """ + if value is None: + return None + elif isinstance(value, dict) and "__type__" in value: + type_name = value["__type__"] + val = value["__value__"] + + if type_name == "bytes": + return base64.b64decode(val.encode("ascii")) + elif type_name == "decimal": + return Decimal(val) + elif type_name == "datetime": + from datetime import datetime + + return datetime.fromisoformat(val) + elif type_name == "date": + from datetime import date + + return date.fromisoformat(val) + else: + return val + elif isinstance(value, list): + return [_deserialize_value(item) for item in value] + elif isinstance(value, dict): + return {k: _deserialize_value(v) for k, v in value.items()} + else: + return value + + +def _json_serializer(obj): + """Custom JSON serializer for edge cases.""" + if hasattr(obj, "date") and hasattr(obj, "time"): # datetime objects + return {"__type__": "datetime", "__value__": obj.isoformat()} + elif hasattr(obj, "isoformat") and not hasattr(obj, "time"): # date objects + return {"__type__": "date", "__value__": obj.isoformat()} + elif isinstance(obj, bytes): + return {"__type__": "bytes", "__value__": base64.b64encode(obj).decode("ascii")} + elif isinstance(obj, Decimal): + return {"__type__": "decimal", "__value__": str(obj)} + else: + return str(obj) # Fallback to string representation + + +# Example usage and testing +if __name__ == "__main__": + import datetime + + # Create a sample PyArrow table with various types + data = { + "integers": [1, 2, 3, 4, 5], + "floats": [1.1, 2.2, 3.3, 4.4, 5.5], + "strings": ["a", "b", "c", "d", "e"], + "booleans": [True, False, True, False, True], + "nulls": [1, None, 3, None, 5], + "dates": [ + datetime.date(2023, 1, 1), + datetime.date(2023, 1, 2), + None, + datetime.date(2023, 1, 4), + datetime.date(2023, 1, 5), + ], + } + + table = pa.table(data) + print("Original table:") + print(table) + print() + + # Serialize the table + serialized = serialize_pyarrow_table(table) + print("Serialized JSON (first 200 chars):") + print(serialized[:200] + "..." 
if len(serialized) > 200 else serialized) + print() + + # Get hash + table_hash = get_pyarrow_table_hash(table) + print(f"Table hash: {table_hash}") + print() + + # Test stability + serialized2 = serialize_pyarrow_table(table) + hash2 = get_pyarrow_table_hash(table) + + print(f"Serialization is stable: {serialized == serialized2}") + print(f"Hash is stable: {table_hash == hash2}") + print() + + # Test with different column order + print("--- Testing column order stability ---") + data_reordered = { + "strings": ["a", "b", "c", "d", "e"], + "integers": [1, 2, 3, 4, 5], + "nulls": [1, None, 3, None, 5], + "floats": [1.1, 2.2, 3.3, 4.4, 5.5], + "booleans": [True, False, True, False, True], + "dates": [ + datetime.date(2023, 1, 1), + datetime.date(2023, 1, 2), + None, + datetime.date(2023, 1, 4), + datetime.date(2023, 1, 5), + ], + } + + table_reordered = pa.table(data_reordered) + serialized_reordered = serialize_pyarrow_table(table_reordered) + hash_reordered = get_pyarrow_table_hash(table_reordered) + + print( + f"Same content, different column order produces same serialization: {serialized == serialized_reordered}" + ) + print( + f"Same content, different column order produces same hash: {table_hash == hash_reordered}" + ) + print() + + # Test schema serialization + print("\n--- Testing schema serialization ---") + + # Create table with metadata + schema = pa.schema( + [ + pa.field( + "integers", + pa.int64(), + metadata={"semantic_type": "id", "other_meta": "ignored"}, + ), + pa.field("floats", pa.float64(), metadata={"semantic_type": "measurement"}), + pa.field("strings", pa.string()), # No metadata + pa.field( + "booleans", pa.bool_(), metadata={"other_meta": "ignored"} + ), # No semantic_type + pa.field("dates", pa.date32(), metadata={"semantic_type": "event_date"}), + ] + ) + + table_with_schema = pa.table(data, schema=schema) + schema_json = serialize_pyarrow_table_schema(table_with_schema) + print(f"Schema JSON: {schema_json}") + + # Parse and display nicely + import json as json_module + + schema_parsed = json_module.loads(schema_json) + print("\nParsed schema:") + for field in schema_parsed: + print(f" {field['name']}: {field['type']} (metadata: {field['metadata']})") + + # Test deserialization + reconstructed = deserialize_to_pyarrow_table(serialized) + print("Reconstructed table:") + print(reconstructed) + print() + + # Verify round-trip + reconstructed_hash = get_pyarrow_table_hash(reconstructed) + print(f"Round-trip hash matches: {table_hash == reconstructed_hash}") + + # Show actual JSON structure for small example + print("\n--- Small example JSON structure ---") + small_table = pa.table( + {"numbers": [1, 2, None], "text": ["hello", "world", "test"]} + ) + small_json = serialize_pyarrow_table(small_table) + print(f"Small table JSON: {small_json}") diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 1f3aca2..61539b5 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -1,6 +1,11 @@ # A collection of utility function that provides a "default" implementation of hashers. # This is often used as the fallback hasher in the library code. 
-from orcapod.hashing.types import CompositeFileHasher, ArrowHasher, FileHasher +from orcapod.hashing.types import ( + CompositeFileHasher, + ArrowHasher, + FileContentHasher, + StringCacher, +) from orcapod.hashing.file_hashers import BasicFileHasher, LegacyPathLikeHasherFactory from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.hashing.object_hashers import ObjectHasher @@ -8,20 +13,41 @@ from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory from orcapod.hashing.arrow_hashers import SemanticArrowHasher from orcapod.hashing.semantic_type_hashers import PathHasher +from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher + + +def get_default_arrow_hasher( + cache_file_hash: bool | StringCacher = True, +) -> ArrowHasher: + """ + Get the default Arrow hasher with semantic type support. + If `with_cache` is True, it uses an in-memory cacher for caching hash values. + """ + arrow_hasher = get_versioned_semantic_arrow_hasher() + if cache_file_hash: + # use unlimited caching + if isinstance(cache_file_hash, StringCacher): + string_cacher = cache_file_hash + else: + string_cacher = InMemoryCacher(max_size=None) + + arrow_hasher.set_cacher("path", string_cacher) + + return arrow_hasher def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: if with_cache: # use unlimited caching string_cacher = InMemoryCacher(max_size=None) - return LegacyPathLikeHasherFactory.create_cached_composite(string_cacher) - return LegacyPathLikeHasherFactory.create_basic_composite() + return LegacyPathLikeHasherFactory.create_cached_legacy_composite(string_cacher) + return LegacyPathLikeHasherFactory.create_basic_legacy_composite() def get_default_composite_file_hasher_with_cacher(cacher=None) -> CompositeFileHasher: if cacher is None: cacher = InMemoryCacher(max_size=None) - return LegacyPathLikeHasherFactory.create_cached_composite(cacher) + return LegacyPathLikeHasherFactory.create_cached_legacy_composite(cacher) def get_default_object_hasher() -> ObjectHasher: @@ -40,16 +66,3 @@ def get_legacy_object_hasher() -> ObjectHasher: ) ) return LegacyObjectHasher(function_info_extractor=function_info_extractor) - - -def get_default_arrow_hasher( - chunk_size: int = 8192, - handle_missing: str = "error", - file_hasher: FileHasher | None = None, -) -> ArrowHasher: - if file_hasher is None: - file_hasher = BasicFileHasher() - hasher = SemanticArrowHasher(chunk_size=chunk_size, handle_missing=handle_missing) - # register semantic hasher for Path - hasher.register_semantic_hasher("Path", PathHasher(file_hasher=file_hasher)) - return hasher diff --git a/src/orcapod/hashing/file_hashers.py b/src/orcapod/hashing/file_hashers.py index 58076ac..cd12e80 100644 --- a/src/orcapod/hashing/file_hashers.py +++ b/src/orcapod/hashing/file_hashers.py @@ -1,7 +1,7 @@ from orcapod.hashing import legacy_core from orcapod.hashing.hash_utils import hash_file from orcapod.hashing.types import ( - FileHasher, + FileContentHasher, PathSetHasher, StringCacher, CompositeFileHasher, @@ -31,7 +31,7 @@ class CachedFileHasher: def __init__( self, - file_hasher: FileHasher, + file_hasher: FileContentHasher, string_cacher: StringCacher, ): self.file_hasher = file_hasher @@ -73,7 +73,7 @@ class LegacyPathsetHasher: def __init__( self, - file_hasher: FileHasher, + file_hasher: FileContentHasher, char_count: int | None = 32, ): self.file_hasher = file_hasher @@ -126,7 +126,7 @@ class LegacyCompositeFileHasher: def __init__( self, - file_hasher: FileHasher, + 
file_hasher: FileContentHasher, char_count: int | None = 32, packet_prefix: str = "", ): @@ -151,27 +151,27 @@ class LegacyPathLikeHasherFactory: """Factory for creating various hasher combinations.""" @staticmethod - def create_basic_composite( + def create_basic_legacy_composite( algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, ) -> CompositeFileHasher: """Create a basic composite hasher.""" - file_hasher = BasicFileHasher(algorithm, buffer_size) + file_hasher = LegacyFileHasher(algorithm, buffer_size) # use algorithm as the prefix for the packet hasher return LegacyCompositeFileHasher( file_hasher, char_count, packet_prefix=algorithm ) @staticmethod - def create_cached_composite( + def create_cached_legacy_composite( string_cacher: StringCacher, algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, ) -> CompositeFileHasher: """Create a composite hasher with file caching.""" - basic_file_hasher = BasicFileHasher(algorithm, buffer_size) + basic_file_hasher = LegacyFileHasher(algorithm, buffer_size) cached_file_hasher = CachedFileHasher(basic_file_hasher, string_cacher) return LegacyCompositeFileHasher( cached_file_hasher, char_count, packet_prefix=algorithm @@ -182,7 +182,7 @@ def create_file_hasher( string_cacher: StringCacher | None = None, algorithm: str = "sha256", buffer_size: int = 65536, - ) -> FileHasher: + ) -> FileContentHasher: """Create just a file hasher, optionally with caching.""" basic_hasher = BasicFileHasher(algorithm, buffer_size) if string_cacher is None: diff --git a/src/orcapod/hashing/semantic_type_hashers.py b/src/orcapod/hashing/semantic_type_hashers.py index 36dfd53..5be28b0 100644 --- a/src/orcapod/hashing/semantic_type_hashers.py +++ b/src/orcapod/hashing/semantic_type_hashers.py @@ -1,4 +1,4 @@ -from .types import SemanticTypeHasher, FileHasher +from orcapod.hashing.types import SemanticTypeHasher, FileContentHasher, StringCacher import os import hashlib import pyarrow as pa @@ -7,7 +7,13 @@ class PathHasher(SemanticTypeHasher): """Hasher for Path semantic type columns - hashes file contents.""" - def __init__(self, file_hasher: FileHasher, handle_missing: str = "error"): + def __init__( + self, + file_hasher: FileContentHasher, + handle_missing: str = "error", + string_cacher: StringCacher | None = None, + cache_key_prefix: str = "path_hasher", + ): """ Initialize PathHasher. 
@@ -17,11 +23,20 @@ def __init__(self, file_hasher: FileHasher, handle_missing: str = "error"): """ self.file_hasher = file_hasher self.handle_missing = handle_missing + self.cacher = string_cacher + self.cache_key_prefix = cache_key_prefix def _hash_file_content(self, file_path: str) -> str: """Hash the content of a single file and return hex string.""" import os + # if cacher exists, check if the hash is cached + if self.cacher: + cache_key = f"{self.cache_key_prefix}:{file_path}" + cached_hash = self.cacher.get_cached(cache_key) + if cached_hash is not None: + return cached_hash + try: if not os.path.exists(file_path): if self.handle_missing == "error": @@ -31,7 +46,13 @@ def _hash_file_content(self, file_path: str) -> str: elif self.handle_missing == "null_hash": return hashlib.sha256(b"").hexdigest() - return self.file_hasher.hash_file(file_path).hex() + hashed_value = self.file_hasher.hash_file(file_path).hex() + if self.cacher: + # Cache the computed hash + self.cacher.set_cached( + f"{self.cache_key_prefix}:{file_path}", hashed_value + ) + return hashed_value except (IOError, OSError, PermissionError) as e: if self.handle_missing == "error": @@ -62,3 +83,11 @@ def hash_column(self, column: pa.Array) -> pa.Array: # Return new array with content hashes instead of paths return pa.array(content_hashes) + + def set_cacher(self, cacher: StringCacher) -> None: + """ + Add a string cacher for caching hash values. + This is a no-op for PathHasher since it hashes file contents directly. + """ + # PathHasher does not use string caching, so this is a no-op + self.cacher = cacher diff --git a/src/orcapod/hashing/string_cachers.py b/src/orcapod/hashing/string_cachers.py index 9b2244a..620dece 100644 --- a/src/orcapod/hashing/string_cachers.py +++ b/src/orcapod/hashing/string_cachers.py @@ -13,10 +13,12 @@ if TYPE_CHECKING: import redis + def _get_redis(): """Lazy import for Redis to avoid circular dependencies.""" try: import redis + return redis except ImportError as e: return None @@ -615,7 +617,9 @@ def __init__( # TODO: cleanup the redis use pattern self._redis_module = _get_redis() if self._redis_module is None: - raise ImportError("Could not import Redis module. redis package is required for RedisCacher") + raise ImportError( + "Could not import Redis module. 
redis package is required for RedisCacher" + ) self.key_prefix = key_prefix self._connection_failed = False self._lock = threading.RLock() @@ -658,7 +662,10 @@ def _test_connection(self) -> None: f"Redis connection established successfully with prefix '{self.key_prefix}'" ) - except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: + except ( + self._redis_module.RedisError, + self._redis_module.ConnectionError, + ) as e: logging.error(f"Failed to establish Redis connection: {e}") raise RuntimeError(f"Redis connection test failed: {e}") @@ -690,7 +697,10 @@ def get_cached(self, cache_key: str) -> str | None: return str(result) - except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: + except ( + self._redis_module.RedisError, + self._redis_module.ConnectionError, + ) as e: self._handle_redis_error("get", e) return None @@ -708,7 +718,10 @@ def set_cached(self, cache_key: str, value: str) -> None: self.redis.set(self._get_prefixed_key(cache_key), value) - except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: + except ( + self._redis_module.RedisError, + self._redis_module.ConnectionError, + ) as e: self._handle_redis_error("set", e) def clear_cache(self) -> None: @@ -722,7 +735,10 @@ def clear_cache(self) -> None: if keys: self.redis.delete(*list(keys)) # type: ignore[arg-type] - except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: + except ( + self._redis_module.RedisError, + self._redis_module.ConnectionError, + ) as e: self._handle_redis_error("clear", e) def is_connected(self) -> bool: diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index 310b5a2..10ed267 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -79,7 +79,7 @@ def hash_to_uuid( @runtime_checkable -class FileHasher(Protocol): +class FileContentHasher(Protocol): """Protocol for file-related hashing.""" def hash_file(self, file_path: PathLike) -> bytes: ... @@ -104,7 +104,7 @@ def hash_packet(self, packet: Packet) -> str: ... class ArrowHasher(Protocol): """Protocol for hashing arrow packets.""" - def hash_table(self, table: pa.Table) -> str: ... + def hash_table(self, table: pa.Table, add_prefix: bool = True) -> str: ... @runtime_checkable @@ -118,7 +118,7 @@ def clear_cache(self) -> None: ... # Combined interface for convenience (optional) @runtime_checkable -class CompositeFileHasher(FileHasher, PathSetHasher, PacketHasher, Protocol): +class CompositeFileHasher(FileContentHasher, PathSetHasher, PacketHasher, Protocol): """Combined interface for all file-related hashing operations.""" pass @@ -142,6 +142,14 @@ class SemanticTypeHasher(Protocol): """Abstract base class for semantic type-specific hashers.""" @abstractmethod - def hash_column(self, column: pa.Array) -> list[bytes]: + def hash_column( + self, + column: pa.Array, + ) -> pa.Array: """Hash a column with this semantic type and return the hash bytes.""" pass + + @abstractmethod + def set_cacher(self, cacher: StringCacher) -> None: + """Add a string cacher for caching hash values.""" + pass diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py new file mode 100644 index 0000000..22c715e --- /dev/null +++ b/src/orcapod/hashing/versioned_hashers.py @@ -0,0 +1,71 @@ +# A collection of versioned hashers that provide a "default" implementation of hashers. 
+from .arrow_hashers import SemanticArrowHasher +import importlib +from typing import Any + +CURRENT_VERSION = "v0.1" + +versioned_hashers = { + "v0.1": { + "_class": "orcapod.hashing.arrow_hashers.SemanticArrowHasher", + "config": { + "hasher_id": "default_v0.1", + "hash_algorithm": "sha256", + "chunk_size": 8192, + "semantic_type_hashers": { + "path": { + "_class": "orcapod.hashing.semantic_type_hashers.PathHasher", + "config": { + "file_hasher": { + "_class": "orcapod.hashing.file_hashers.BasicFileHasher", + "config": { + "algorithm": "sha256", + }, + } + }, + } + }, + }, + } +} + + +def parse_objectspec(obj_spec: dict) -> Any: + if "_class" in obj_spec: + # if _class is specified, treat the dict as an object specification + module_name, class_name = obj_spec["_class"].rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + configs = parse_objectspec(obj_spec.get("config", {})) + return cls(**configs) + else: + # otherwise, parse through the dictionary recursively + parsed_object = obj_spec + for k, v in obj_spec.items(): + if isinstance(v, dict): + parsed_object[k] = parse_objectspec(v) + else: + parsed_object[k] = v + return parsed_object + + +def get_versioned_semantic_arrow_hasher( + version: str | None = None, +) -> SemanticArrowHasher: + """ + Get the versioned hasher for the specified version. + + Args: + version (str): The version of the hasher to retrieve. + + Returns: + SemanticArrowHasher: An instance of the hasher for the specified version. + """ + if version is None: + version = CURRENT_VERSION + + if version not in versioned_hashers: + raise ValueError(f"Unsupported hasher version: {version}") + + hasher_spec = versioned_hashers[version] + return parse_objectspec(hasher_spec) diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index e953f1f..4396223 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -2,69 +2,23 @@ from orcapod.core import SyncStream, Source, Kernel from orcapod.store import ArrowDataStore from orcapod.types import Tag, Packet, TypeSpec, default_registry -from orcapod.types.typespec import extract_function_typespecs +from orcapod.types.typespec_utils import get_typespec_from_dict, union_typespecs, extract_function_typespecs +from orcapod.types.semantic_type_registry import create_arrow_table_with_meta from orcapod.hashing import ObjectHasher, ArrowHasher from orcapod.hashing.defaults import get_default_object_hasher, get_default_arrow_hasher from typing import Any, Literal from collections.abc import Collection, Iterator -from orcapod.types.registry import TypeRegistry, PacketConverter +from orcapod.types.semantic_type_registry import TypeRegistry +from orcapod.types.packet_converter import PacketConverter import pyarrow as pa import polars as pl from orcapod.core.streams import SyncStreamFromGenerator -from orcapod.utils.stream_utils import get_typespec, union_typespecs import logging logger = logging.getLogger(__name__) -def tag_to_arrow_table_with_metadata(tag, metadata: dict | None = None): - """ - Convert a tag dictionary to PyArrow table with metadata on each column. 
- - Args: - tag: Dictionary with string keys and any Python data type values - metadata_key: The metadata key to add to each column - metadata_value: The metadata value to indicate this column came from tag - """ - if metadata is None: - metadata = {} - - # First create the table to infer types - temp_table = pa.Table.from_pylist([tag]) - - # Create new fields with metadata - fields_with_metadata = [] - for field in temp_table.schema: - # Add metadata to each field - field_metadata = metadata - new_field = pa.field( - field.name, field.type, nullable=field.nullable, metadata=field_metadata - ) - fields_with_metadata.append(new_field) - - # Create schema with metadata - schema_with_metadata = pa.schema(fields_with_metadata) - - # Create the final table with the metadata-enriched schema - table = pa.Table.from_pylist([tag], schema=schema_with_metadata) - - return table - - -def get_columns_with_metadata( - df: pl.DataFrame, key: str, value: str | None = None -) -> list[str]: - """Get column names with specific metadata using list comprehension. If value is given, only - columns matching that specific value for the desginated metadata key will be returned. - Otherwise, all columns that contains the key as metadata will be returned regardless of the value""" - return [ - col_name - for col_name, dtype in df.schema.items() - if hasattr(dtype, "metadata") - and (value is None or getattr(dtype, "metadata") == value) - ] - class PolarsSource(Source): def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None): @@ -81,18 +35,15 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: class PolarsStream(SyncStream): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str]): self.df = df - if tag_keys is None: - # extract tag_keys by picking columns with metadata source=tag - tag_keys = get_columns_with_metadata(df, "source", "tag") self.tag_keys = tag_keys def __iter__(self) -> Iterator[tuple[Tag, Packet]]: for row in self.df.iter_rows(named=True): tag = {key: row[key] for key in self.tag_keys} packet = {key: val for key, val in row.items() if key not in self.tag_keys} - yield tag, packet + yield tag, Packet(packet) class EmptyStream(SyncStream): @@ -266,26 +217,44 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: output_stream = self.kernel.forward(*resolved_streams, **kwargs) - tag_type, packet_type = output_stream.types(trigger_run=False) - if tag_type is not None and packet_type is not None: - joined_type = union_typespecs(tag_type, packet_type) + tag_typespec, packet_typespec = output_stream.types(trigger_run=False) + if tag_typespec is not None and packet_typespec is not None: + joined_type = union_typespecs(tag_typespec, packet_typespec) assert joined_type is not None, "Joined typespec should not be None" - self.output_converter = PacketConverter(joined_type, registry=self.registry) + all_type = dict(joined_type) + for k in packet_typespec: + all_type[f'_source_{k}'] = str + # + self.output_converter = PacketConverter(all_type, registry=self.registry) # Cache the output stream of the underlying kernel - # This is a no-op if the output stream is already cached + # If an entry with same tag and packet already exists in the output store, + # it will not be added again, thus avoiding duplicates. 
def generator() -> Iterator[tuple[Tag, Packet]]: logger.info(f"Computing and caching outputs for {self}") for tag, packet in output_stream: merged_info = {**tag, **packet} + # add entries for source_info + for k, v in packet.source_info.items(): + merged_info[f'_source_{k}'] = v + if self.output_converter is None: - joined_type = get_typespec(merged_info) + # TODO: cleanup logic here + joined_type = get_typespec_from_dict(merged_info) assert joined_type is not None, "Joined typespec should not be None" + all_type = dict(joined_type) + for k in packet: + all_type[f'_source_{k}'] = str self.output_converter = PacketConverter( - joined_type, registry=self.registry + all_type, registry=self.registry ) + # add entries for source_info + for k, v in packet.source_info.items(): + merged_info[f'_source_{k}'] = v + output_table = self.output_converter.to_arrow_table(merged_info) + # TODO: revisit this logic output_id = self.arrow_hasher.hash_table(output_table) if not self.output_store.get_record(*self.source_info, output_id): self.output_store.add_record( @@ -463,7 +432,6 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: return super().forward(*streams, **kwargs) def get_packet_key(self, packet: Packet) -> str: - # TODO: reconsider the logic around input/output converter -- who should own this? return self.arrow_hasher.hash_table(self.input_converter.to_arrow_table(packet)) @property @@ -473,23 +441,25 @@ def source_info(self): def is_memoized(self, packet: Packet) -> bool: return self.retrieve_memoized(packet) is not None - def add_tag_record(self, tag: Tag, packet: Packet) -> Tag: + def add_pipeline_record(self, tag: Tag, packet: Packet) -> Tag: """ Record the tag for the packet in the record store. This is used to keep track of the tags associated with memoized packets. """ - return self._add_tag_record_with_packet_key(tag, self.get_packet_key(packet)) + return self._add_pipeline_record_with_packet_key(tag, self.get_packet_key(packet), packet.source_info) - def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: + def _add_pipeline_record_with_packet_key(self, tag: Tag, packet_key: str, packet_source_info: dict[str, str | None]) -> Tag: if self.tag_store is None: raise ValueError("Recording of tag requires tag_store but none provided") - tag = dict(tag) # ensure we don't modify the original tag - tag["__packet_key"] = packet_key + combined_info = dict(tag) # ensure we don't modify the original tag + combined_info["__packet_key"] = packet_key + for k, v in packet_source_info.items(): + combined_info[f'__{k}_source'] = v # TODO: consider making this more efficient # convert tag to arrow table - columns are labeled with metadata source=tag - table = tag_to_arrow_table_with_metadata(tag, {"source": "tag"}) + table = create_arrow_table_with_meta(combined_info, {"source": "tag"}) entry_hash = self.arrow_hasher.hash_table(table) @@ -553,8 +523,7 @@ def _memoize_with_packet_key( # consider simpler alternative packets = self.output_converter.from_arrow_table( self.output_store.add_record( - self.function_pod.function_name, - self.function_pod_hash, + *self.source_info, packet_key, self.output_converter.to_arrow_table(output_packet), ) @@ -563,7 +532,13 @@ def _memoize_with_packet_key( assert len(packets) == 1, ( f"Memoizing single packet returned {len(packets)} packets!" 
) - return packets[0] + packet = packets[0] + # TODO: reconsider the right place to attach this information + # attach provenance information + packet_source_id = ":".join(self.source_info + (packet_key,)) + source_info = {k: f'{packet_source_id}:{k}' for k in packet} + return Packet(packet, source_info=source_info) + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: packet_key = "" @@ -603,7 +578,7 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: # result was successfully computed -- save the tag if not self.skip_tag_record and self.tag_store is not None: - self._add_tag_record_with_packet_key(tag, packet_key) + self._add_pipeline_record_with_packet_key(tag, packet_key, packet.source_info) return tag, output_packet diff --git a/src/orcapod/types/__init__.py b/src/orcapod/types/__init__.py index e51a6f8..a4615f5 100644 --- a/src/orcapod/types/__init__.py +++ b/src/orcapod/types/__init__.py @@ -1,12 +1,13 @@ -from .core import Tag, Packet, TypeSpec, PathLike, PathSet, PodFunction -from .registry import TypeRegistry -from .handlers import PathHandler, UUIDHandler, DateTimeHandler -from . import handlers -from . import typespec +from .core import Tag, PathLike, PathSet, PodFunction, TypeSpec +from .packets import Packet +from .semantic_type_registry import SemanticTypeRegistry +from .semantic_type_handlers import PathHandler, UUIDHandler, DateTimeHandler +from . import semantic_type_handlers +from . import typespec_utils # Create default registry and register handlers -default_registry = TypeRegistry() +default_registry = SemanticTypeRegistry() # Register with semantic names - registry extracts supported types automatically default_registry.register("path", PathHandler()) @@ -19,10 +20,11 @@ "default_registry", "Tag", "Packet", + "PacketLike" "TypeSpec", "PathLike", "PathSet", "PodFunction", - "handlers", - "typespec", + "semantic_type_handlers", + "typespec_utils", ] diff --git a/src/orcapod/types/core.py b/src/orcapod/types/core.py index 097750e..dd02141 100644 --- a/src/orcapod/types/core.py +++ b/src/orcapod/types/core.py @@ -1,19 +1,10 @@ -from typing import Protocol, Any, TypeAlias +from typing import Protocol, Any, TypeAlias, TypeVar, Generic import pyarrow as pa from dataclasses import dataclass import os from collections.abc import Collection, Mapping -# TODO: reconsider the need for this dataclass as its information is superfluous -# to the registration of the handler into the registry. 
-@dataclass -class TypeInfo: - python_type: type - arrow_type: pa.DataType - semantic_type: str | None # name under which the type is registered - handler: "TypeHandler" - DataType: TypeAlias = type @@ -22,8 +13,6 @@ class TypeInfo: ] # Mapping of parameter names to their types -SUPPORTED_PYTHON_TYPES = (str, int, float, bool, bytes) - # Convenience alias for anything pathlike PathLike = str | os.PathLike @@ -45,14 +34,8 @@ class TypeInfo: # Extended data values that can be stored in packets # Either the original PathSet or one of our supported simple data types -DataValue: TypeAlias = PathSet | SupportedNativePythonData | Collection["DataValue"] - - -# a packet is a mapping from string keys to data values -Packet: TypeAlias = Mapping[str, DataValue] +DataValue: TypeAlias = PathSet | SupportedNativePythonData | None | Collection["DataValue"] -# a batch is a tuple of a tag and a list of packets -Batch: TypeAlias = tuple[Tag, Collection[Packet]] class PodFunction(Protocol): @@ -68,7 +51,7 @@ def __call__(self, **kwargs: DataValue) -> None | DataValue | list[DataValue]: . class TypeHandler(Protocol): - """Protocol for handling conversion between Python types and underlying Arrow + """Protocol for handling conversion between Python type and Arrow data types used for storage. The handler itself IS the definition of a semantic type. The semantic type @@ -78,11 +61,11 @@ class TypeHandler(Protocol): and focus purely on conversion logic. """ - def python_types(self) -> type | tuple[type, ...]: + def python_type(self) -> type: """Return the Python type(s) this handler can process. Returns: - Single Type or tuple of Types this handler supports + Python type the handler supports Examples: - PathHandler: return Path @@ -91,7 +74,7 @@ def python_types(self) -> type | tuple[type, ...]: """ ... - def storage_type(self) -> pa.DataType: + def storage_type(self) -> type: """Return the Arrow DataType instance for schema definition.""" ... 
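
Editor's note (not part of the patch): the reworked TypeHandler protocol above exposes python_type(), storage_type(), python_to_storage(), and storage_to_python(), and handlers are attached to a SemanticTypeRegistry via register("path", PathHandler()) as seen in orcapod/types/__init__.py. The sketch below illustrates, under stated assumptions, what a user-defined handler might look like against that protocol. "UrlHandler", the "url" semantic name, and the choice of str as the storage type are hypothetical; the registry import path and register() call are taken from this patch, but the new SemanticTypeRegistry internals are not shown here, so treat this as a rough usage sketch rather than a definitive implementation.

# Minimal sketch of a custom semantic type handler (assumptions noted above).
from urllib.parse import urlparse, ParseResult

from orcapod.types.semantic_type_registry import SemanticTypeRegistry


class UrlHandler:
    """Hypothetical handler that stores urllib ParseResult values as plain strings."""

    def python_type(self) -> type:
        # The Python-side type this handler is responsible for.
        return ParseResult

    def storage_type(self) -> type:
        # Values are persisted as strings; the schema layer maps this to an Arrow type.
        return str

    def python_to_storage(self, value: ParseResult) -> str:
        # Serialize the parsed URL back to its string form for storage.
        return value.geturl()

    def storage_to_python(self, value: str) -> ParseResult:
        # Reconstruct the Python-side value from its stored string form.
        return urlparse(value)


# Registration mirrors default_registry.register("path", PathHandler()) in
# orcapod/types/__init__.py: the semantic name identifies the type, and the
# registry is expected to infer the supported Python type from the handler.
registry = SemanticTypeRegistry()
registry.register("url", UrlHandler())
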
diff --git a/src/orcapod/types/packet_converter.py b/src/orcapod/types/packet_converter.py new file mode 100644 index 0000000..0a8389d --- /dev/null +++ b/src/orcapod/types/packet_converter.py @@ -0,0 +1,177 @@ +from orcapod.types.core import TypeSpec, TypeHandler +from orcapod.types.packets import Packet, PacketLike +from orcapod.types.semantic_type_registry import SemanticTypeRegistry, TypeInfo, get_metadata_from_schema, arrow_to_dicts +from typing import Any +from collections.abc import Mapping, Sequence +import pyarrow as pa +import logging + +logger = logging.getLogger(__name__) + + +def is_packet_supported( + python_type_info: TypeSpec, registry: SemanticTypeRegistry, type_lut: dict | None = None +) -> bool: + """Check if all types in the packet are supported by the registry or known to the default lut.""" + if type_lut is None: + type_lut = {} + return all( + python_type in registry or python_type in type_lut + for python_type in python_type_info.values() + ) + + + +class PacketConverter: + def __init__(self, python_type_spec: TypeSpec, registry: SemanticTypeRegistry): + self.python_type_spec = python_type_spec + self.registry = registry + + # Lookup handlers and type info for fast access + self.handlers: dict[str, TypeHandler] = {} + self.storage_type_info: dict[str, TypeInfo] = {} + + self.expected_key_set = set(python_type_spec.keys()) + + # prepare the corresponding arrow table schema with metadata + self.keys_with_handlers, self.schema = create_schema_from_python_type_info( + python_type_spec, registry + ) + + self.semantic_type_lut = get_metadata_from_schema(self.schema, b"semantic_type") + + def _check_key_consistency(self, keys): + """Check if the provided keys match the expected keys.""" + keys_set = set(keys) + if keys_set != self.expected_key_set: + missing_keys = self.expected_key_set - keys_set + extra_keys = keys_set - self.expected_key_set + error_parts = [] + if missing_keys: + error_parts.append(f"Missing keys: {missing_keys}") + if extra_keys: + error_parts.append(f"Extra keys: {extra_keys}") + + raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") + + def _to_storage_packet(self, packet: PacketLike) -> dict[str, Any]: + """Convert packet to storage representation. + + Args: + packet: Dictionary mapping parameter names to Python values + + Returns: + Dictionary with same keys but values converted to storage format + + Raises: + KeyError: If packet keys don't match the expected type_info keys + TypeError: If value type doesn't match expected type + ValueError: If conversion fails + """ + # Validate packet keys + packet_keys = set(packet.keys()) + + self._check_key_consistency(packet_keys) + + # Convert each value + storage_packet: dict[str, Any] = dict(packet) # Start with a copy of the packet + + for key, handler in self.keys_with_handlers: + try: + storage_packet[key] = handler.python_to_storage(storage_packet[key]) + except Exception as e: + raise ValueError(f"Failed to convert value for '{key}': {e}") from e + + return storage_packet + + def _from_storage_packet(self, storage_packet: Mapping[str, Any]) -> PacketLike: + """Convert storage packet back to Python packet. 
+ + Args: + storage_packet: Dictionary with values in storage format + + Returns: + Packet with values converted back to Python types + + Raises: + KeyError: If storage packet keys don't match the expected type_info keys + TypeError: If value type doesn't match expected type + ValueError: If conversion fails + """ + # Validate storage packet keys + storage_keys = set(storage_packet.keys()) + + self._check_key_consistency(storage_keys) + + # Convert each value back to Python type + packet: PacketLike = dict(storage_packet) + + for key, handler in self.keys_with_handlers: + try: + packet[key] = handler.storage_to_python(storage_packet[key]) + except Exception as e: + raise ValueError(f"Failed to convert value for '{key}': {e}") from e + + return packet + + def to_arrow_table(self, packet: PacketLike | Sequence[PacketLike]) -> pa.Table: + """Convert packet to PyArrow Table with field metadata. + + Args: + packet: Dictionary mapping parameter names to Python values + + Returns: + PyArrow Table with the packet data as a single row + """ + # Convert packet to storage format + if not isinstance(packet, Sequence): + packets = [packet] + else: + packets = packet + + storage_packets = [self._to_storage_packet(p) for p in packets] + + # Create arrays + arrays = [] + for field in self.schema: + values = [p[field.name] for p in storage_packets] + array = pa.array(values, type=field.type) + arrays.append(array) + + return pa.Table.from_arrays(arrays, schema=self.schema) + + def from_arrow_table( + self, table: pa.Table, verify_semantic_equivalence: bool = True + ) -> list[Packet]: + """Convert Arrow table to packet with field metadata. + + Args: + table: PyArrow Table with metadata + + Returns: + List of packets converted from the Arrow table + """ + # Check for consistency in the semantic type mapping: + semantic_type_info = get_metadata_from_schema(table.schema, b"semantic_type") + + if semantic_type_info != self.semantic_type_lut: + if not verify_semantic_equivalence: + logger.warning( + "Arrow table semantic types do not match expected type registry. " + f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" + ) + else: + raise ValueError( + "Arrow table semantic types do not match expected type registry. 
" + f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" + ) + + # Create packets from the Arrow table + # TODO: make this more efficient + storage_packets: list[Packet] = arrow_to_dicts(table) # type: ignore + if not self.keys_with_handlers: + # no special handling required + return storage_packets + + return [Packet(self._from_storage_packet(packet)) for packet in storage_packets] + diff --git a/src/orcapod/types/packets.py b/src/orcapod/types/packets.py new file mode 100644 index 0000000..a8d8775 --- /dev/null +++ b/src/orcapod/types/packets.py @@ -0,0 +1,241 @@ +from orcapod.types.core import DataValue +from typing import TypeAlias, Any +from collections.abc import Mapping, Collection +from orcapod.types.core import TypeSpec, Tag, TypeHandler +from orcapod.types.semantic_type_registry import SemanticTypeRegistry +from orcapod.types import schemas +import pyarrow as pa + +# # a packet is a mapping from string keys to data values +PacketLike: TypeAlias = Mapping[str, DataValue] + + +class Packet(dict[str, DataValue]): + def __init__( + self, + obj: PacketLike | None = None, + typespec: TypeSpec | None = None, + source_info: dict[str, str|None] | None = None + ): + if obj is None: + obj = {} + super().__init__(obj) + if typespec is None: + from orcapod.types.typespec_utils import get_typespec_from_dict + typespec = get_typespec_from_dict(self) + self._typespec = typespec + if source_info is None: + source_info = {} + self._source_info = source_info + + @property + def typespec(self) -> TypeSpec: + # consider returning a copy for immutability + return self._typespec + + @property + def source_info(self) -> dict[str, str | None]: + return {key: self._source_info.get(key, None) for key in self.keys()} + + + +# a batch is a tuple of a tag and a list of packets +Batch: TypeAlias = tuple[Tag, Collection[Packet]] + + +class SemanticPacket(dict[str, Any]): + """ + A packet that conforms to a semantic schema, mapping string keys to values. + + This is used to represent data packets in OrcaPod with semantic types. + + Attributes + ---------- + keys : str + The keys of the packet. + values : Any + The values corresponding to each key. 
+ + Examples + -------- + >>> packet = SemanticPacket(name='Alice', age=30) + >>> print(packet) + {'name': 'Alice', 'age': 30} + """ + def __init__(self, *args, semantic_schema: schemas.SemanticSchema | None = None, source_info: dict[str, str|None] | None = None, **kwargs): + super().__init__(*args, **kwargs) + self.schema = semantic_schema + if source_info is None: + source_info = {} + self.source_info = source_info + + + +class PacketConverter: + def __init__(self, python_schema: schemas.PythonSchema, registry: SemanticTypeRegistry, include_source_info: bool = True): + self.python_schema = python_schema + self.registry = registry + + self.semantic_schema = schemas.from_python_schema_to_semantic_schema( + python_schema, registry + ) + + self.include_source_info = include_source_info + + self.arrow_schema = schemas.from_semantic_schema_to_arrow_schema( + self.semantic_schema, include_source_info=self.include_source_info + ) + + + + self.key_handlers: dict[str, TypeHandler] = {} + + self.expected_key_set = set(self.python_schema.keys()) + + for key, (_, semantic_type) in self.semantic_schema.items(): + if semantic_type is None: + continue + handler = registry.get_handler_by_semantic_type(semantic_type) + if handler is None: + raise ValueError( + f"No handler found for semantic type '{semantic_type}' in key '{key}'" + ) + self.key_handlers[key] = handler + + def _check_key_consistency(self, keys): + """Check if the provided keys match the expected keys.""" + keys_set = set(keys) + if keys_set != self.expected_key_set: + missing_keys = self.expected_key_set - keys_set + extra_keys = keys_set - self.expected_key_set + error_parts = [] + if missing_keys: + error_parts.append(f"Missing keys: {missing_keys}") + if extra_keys: + error_parts.append(f"Extra keys: {extra_keys}") + + raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") + + def from_python_packet_to_semantic_packet(self, python_packet: PacketLike) -> SemanticPacket: + """Convert a Python packet to a semantic packet. + + Args: + python_packet: Dictionary mapping parameter names to Python values + + Returns: + Packet with values converted to semantic types + + Raises: + KeyError: If packet keys don't match the expected type_info keys + TypeError: If value type doesn't match expected type + ValueError: If conversion fails + """ + # Validate packet keys + semantic_packet = SemanticPacket(python_packet, semantic_schema=self.semantic_schema, source_info=getattr(python_packet, "source_info", None)) + self._check_key_consistency(set(semantic_packet.keys())) + + # convert from storage to Python types for semantic types + for key, handler in self.key_handlers.items(): + try: + semantic_packet[key] = handler.python_to_storage( + semantic_packet[key] + ) + except Exception as e: + raise ValueError(f"Failed to convert value for '{key}': {e}") from e + + return semantic_packet + + + + def from_python_packet_to_arrow_table(self, python_packet: PacketLike) -> pa.Table: + """Convert a Python packet to an Arrow table. + + Args: + python_packet: Dictionary mapping parameter names to Python values + + Returns: + Arrow table representation of the packet + """ + semantic_packet = self.from_python_packet_to_semantic_packet(python_packet) + return self.from_semantic_packet_to_arrow_table(semantic_packet) + + def from_semantic_packet_to_arrow_table(self, semantic_packet: SemanticPacket) -> pa.Table: + """Convert a semantic packet to an Arrow table. 
+ + Args: + semantic_packet: SemanticPacket with values to convert + + Returns: + Arrow table representation of the packet + """ + arrays = [] + for field in self.arrow_schema: + value = semantic_packet.get(field.name, None) + arrays.append(pa.array([value], type=field.type)) + + if self.include_source_info: + for field, value in semantic_packet.source_info.items(): + arrays.append(pa.array([value], type=pa.large_string())) + + return pa.Table.from_arrays(arrays, schema=self.arrow_schema) + + def from_arrow_table_to_semantic_packets(self, arrow_table: pa.Table) -> Collection[SemanticPacket]: + """Convert an Arrow table to a semantic packet. + + Args: + arrow_table: Arrow table representation of the packet + + Returns: + SemanticPacket with values converted from Arrow types + """ + # TODO: this is a crude check, implement more robust one to check that + # schema matches what's expected + if not arrow_table.schema.equals(self.arrow_schema): + raise ValueError("Arrow table schema does not match expected schema") + + semantic_packets_contents = arrow_table.to_pylist() + + semantic_packets = [] + for all_packet_content in semantic_packets_contents: + packet_content = {k: v for k, v in all_packet_content.items() if k in self.expected_key_set} + source_info = {k.strip('_source_info_'): v for k, v in all_packet_content.items() if k.startswith('_source_info_')} + semantic_packets.append(SemanticPacket(packet_content, _semantic_schema=self.semantic_schema, _source_info=source_info)) + + return semantic_packets + + def from_semantic_packet_to_python_packet(self, semantic_packet: SemanticPacket) -> Packet: + """Convert a semantic packet to a Python packet. + + Args: + semantic_packet: SemanticPacket with values to convert + + Returns: + Python packet representation of the semantic packet + """ + # Validate packet keys + python_packet = Packet(semantic_packet, typespec=self.python_schema, source_info=semantic_packet.source_info) + packet_keys = set(python_packet.keys()) + self._check_key_consistency(packet_keys) + + for key, handler in self.key_handlers.items(): + try: + python_packet[key] = handler.storage_to_python( + python_packet[key] + ) + except Exception as e: + raise ValueError(f"Failed to convert value for '{key}': {e}") from e + + return python_packet + + def from_arrow_table_to_python_packets(self, arrow_table: pa.Table) -> list[Packet]: + """Convert an Arrow table to a list of Python packets. 
+ + Args: + arrow_table: Arrow table representation of the packets + + Returns: + List of Python packets converted from the Arrow table + """ + semantic_packets = self.from_arrow_table_to_semantic_packets(arrow_table) + return [self.from_semantic_packet_to_python_packet(sp) for sp in semantic_packets] + diff --git a/src/orcapod/types/registry.py b/src/orcapod/types/registry.py deleted file mode 100644 index 6b56183..0000000 --- a/src/orcapod/types/registry.py +++ /dev/null @@ -1,437 +0,0 @@ -from collections.abc import Callable, Collection, Sequence, Mapping -import logging -from optparse import Values -from typing import Any -import pyarrow as pa -from orcapod.types import Packet -from .core import TypeHandler, TypeInfo, TypeSpec - -# This mapping is expected to be stable -# Be sure to test this assumption holds true -DEFAULT_ARROW_TYPE_LUT = { - int: pa.int64(), - float: pa.float64(), - str: pa.string(), - bool: pa.bool_(), -} - -logger = logging.getLogger(__name__) - - -class TypeRegistry: - """Registry that manages type handlers with semantic type names.""" - - def __init__(self): - self._handlers: dict[ - type, tuple[TypeHandler, str] - ] = {} # Type -> (Handler, semantic_name) - self._semantic_handlers: dict[str, TypeHandler] = {} # semantic_name -> Handler - - def register( - self, - semantic_name: str, - handler: TypeHandler, - explicit_types: type | tuple[type, ...] | None = None, - override: bool = False, - ): - """Register a handler with a semantic type name. - - Args: - semantic_name: Identifier for this semantic type (e.g., 'path', 'uuid') - handler: The type handler instance - explicit_types: Optional override of types to register for (if different from handler's supported_types) - override: If True, allow overriding existing registration for the same semantic name and Python type(s) - """ - # Determine which types to register for - if explicit_types is not None: - types_to_register = ( - explicit_types - if isinstance(explicit_types, tuple) - else (explicit_types,) - ) - else: - supported = handler.python_types() - types_to_register = ( - supported if isinstance(supported, tuple) else (supported,) - ) - - # Register handler for each type - for python_type in types_to_register: - if python_type in self._handlers and not override: - existing_semantic = self._handlers[python_type][1] - # TODO: handle overlapping registration more gracefully - raise ValueError( - f"Type {python_type} already registered with semantic type '{existing_semantic}'" - ) - - self._handlers[python_type] = (handler, semantic_name) - - # Register by semantic name - if semantic_name in self._semantic_handlers and not override: - raise ValueError(f"Semantic type '{semantic_name}' already registered") - - self._semantic_handlers[semantic_name] = handler - - def get_handler(self, python_type: type) -> TypeHandler | None: - """Get handler for a Python type.""" - handler_info = self._handlers.get(python_type) - return handler_info[0] if handler_info else None - - def get_semantic_name(self, python_type: type) -> str | None: - """Get semantic name for a Python type.""" - handler_info = self._handlers.get(python_type) - return handler_info[1] if handler_info else None - - def get_type_info(self, python_type: type) -> TypeInfo | None: - """Get TypeInfo for a Python type.""" - handler = self.get_handler(python_type) - if handler is None: - return None - semantic_name = self.get_semantic_name(python_type) - return TypeInfo( - python_type=python_type, - arrow_type=handler.storage_type(), - 
semantic_type=semantic_name, - handler=handler, - ) - - def get_handler_by_semantic_name(self, semantic_name: str) -> TypeHandler | None: - """Get handler by semantic name.""" - return self._semantic_handlers.get(semantic_name) - - def __contains__(self, python_type: type) -> bool: - """Check if a Python type is registered.""" - return python_type in self._handlers - - -class PacketConverter: - def __init__(self, python_type_spec: TypeSpec, registry: TypeRegistry): - self.python_type_spec = python_type_spec - self.registry = registry - - # Lookup handlers and type info for fast access - self.handlers: dict[str, TypeHandler] = {} - self.storage_type_info: dict[str, TypeInfo] = {} - - self.expected_key_set = set(python_type_spec.keys()) - - # prepare the corresponding arrow table schema with metadata - self.keys_with_handlers, self.schema = create_schema_from_python_type_info( - python_type_spec, registry - ) - - self.semantic_type_lut = get_metadata_from_schema(self.schema, b"semantic_type") - - def _check_key_consistency(self, keys): - """Check if the provided keys match the expected keys.""" - keys_set = set(keys) - if keys_set != self.expected_key_set: - missing_keys = self.expected_key_set - keys_set - extra_keys = keys_set - self.expected_key_set - error_parts = [] - if missing_keys: - error_parts.append(f"Missing keys: {missing_keys}") - if extra_keys: - error_parts.append(f"Extra keys: {extra_keys}") - - raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") - - def _to_storage_packet(self, packet: Packet) -> dict[str, Any]: - """Convert packet to storage representation. - - Args: - packet: Dictionary mapping parameter names to Python values - - Returns: - Dictionary with same keys but values converted to storage format - - Raises: - KeyError: If packet keys don't match the expected type_info keys - TypeError: If value type doesn't match expected type - ValueError: If conversion fails - """ - # Validate packet keys - packet_keys = set(packet.keys()) - - self._check_key_consistency(packet_keys) - - # Convert each value - storage_packet: dict[str, Any] = dict(packet) # Start with a copy of the packet - - for key, handler in self.keys_with_handlers: - try: - storage_packet[key] = handler.python_to_storage(storage_packet[key]) - except Exception as e: - raise ValueError(f"Failed to convert value for '{key}': {e}") from e - - return storage_packet - - def _from_storage_packet(self, storage_packet: Mapping[str, Any]) -> Packet: - """Convert storage packet back to Python packet. - - Args: - storage_packet: Dictionary with values in storage format - - Returns: - Packet with values converted back to Python types - - Raises: - KeyError: If storage packet keys don't match the expected type_info keys - TypeError: If value type doesn't match expected type - ValueError: If conversion fails - """ - # Validate storage packet keys - storage_keys = set(storage_packet.keys()) - - self._check_key_consistency(storage_keys) - - # Convert each value back to Python type - packet: Packet = dict(storage_packet) - - for key, handler in self.keys_with_handlers: - try: - packet[key] = handler.storage_to_python(storage_packet[key]) - except Exception as e: - raise ValueError(f"Failed to convert value for '{key}': {e}") from e - - return packet - - def to_arrow_table(self, packet: Packet | Sequence[Packet]) -> pa.Table: - """Convert packet to PyArrow Table with field metadata. 
- - Args: - packet: Dictionary mapping parameter names to Python values - - Returns: - PyArrow Table with the packet data as a single row - """ - # Convert packet to storage format - if not isinstance(packet, Sequence): - packets = [packet] - else: - packets = packet - - storage_packets = [self._to_storage_packet(p) for p in packets] - - # Create arrays - arrays = [] - for field in self.schema: - values = [p[field.name] for p in storage_packets] - array = pa.array(values, type=field.type) - arrays.append(array) - - return pa.Table.from_arrays(arrays, schema=self.schema) - - def from_arrow_table( - self, table: pa.Table, verify_semantic_equivalence: bool = True - ) -> list[Packet]: - """Convert Arrow table to packet with field metadata. - - Args: - table: PyArrow Table with metadata - - Returns: - List of packets converted from the Arrow table - """ - # Check for consistency in the semantic type mapping: - semantic_type_info = get_metadata_from_schema(table.schema, b"semantic_type") - - if semantic_type_info != self.semantic_type_lut: - if not verify_semantic_equivalence: - logger.warning( - "Arrow table semantic types do not match expected type registry. " - f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" - ) - else: - raise ValueError( - "Arrow table semantic types do not match expected type registry. " - f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" - ) - - # Create packets from the Arrow table - # TODO: make this more efficient - storage_packets: list[Packet] = arrow_to_dicts(table) # type: ignore - if not self.keys_with_handlers: - # no special handling required - return storage_packets - - return [self._from_storage_packet(packet) for packet in storage_packets] - - -def arrow_to_dicts(table: pa.Table) -> list[dict[str, Any]]: - """ - Convert Arrow table to dictionary or list of dictionaries. - By default returns a list of dictionaries (one per row) with column names as keys. - If `collapse_singleton` is True, return a single dictionary for single-row tables. - Args: - table: PyArrow Table to convert - collapse_singleton: If True, return a single dictionary for single-row tables. Defaults to False. - Returns: - A dictionary if singleton and collapse_singleton=True. Otherwise, list of dictionaries for multi-row tables. - """ - if len(table) == 0: - return [] - - # Multiple rows: return list of dicts (one per row) - return [ - {col_name: table.column(col_name)[i].as_py() for col_name in table.column_names} - for i in range(len(table)) - ] - - -def get_metadata_from_schema( - schema: pa.Schema, metadata_field: bytes -) -> dict[str, str]: - """ - Extract metadata from Arrow schema fields. Metadata value will be utf-8 decoded. 
- Args: - schema: PyArrow Schema to extract metadata from - metadata_field: Metadata field to extract (e.g., b'semantic_type') - Returns: - Dictionary mapping field names to their metadata values - """ - metadata = {} - for field in schema: - if field.metadata and metadata_field in field.metadata: - metadata[field.name] = field.metadata[metadata_field].decode("utf-8") - return metadata - - -def create_schema_from_python_type_info( - python_type_spec: TypeSpec, - registry: TypeRegistry, - arrow_type_lut: dict[type, pa.DataType] | None = None, -) -> tuple[list[tuple[str, TypeHandler]], pa.Schema]: - if arrow_type_lut is None: - arrow_type_lut = DEFAULT_ARROW_TYPE_LUT - keys_with_handlers: list[tuple[str, TypeHandler]] = [] - schema_fields = [] - for key, python_type in python_type_spec.items(): - type_info = registry.get_type_info(python_type) - - field_metadata = {} - if type_info and type_info.semantic_type: - field_metadata["semantic_type"] = type_info.semantic_type - keys_with_handlers.append((key, type_info.handler)) - arrow_type = type_info.arrow_type - else: - arrow_type = arrow_type_lut.get(python_type) - if arrow_type is None: - raise ValueError( - f"Direct support for Python type {python_type} is not provided. Register a handler to work with {python_type}" - ) - - schema_fields.append(pa.field(key, arrow_type, metadata=field_metadata)) - return keys_with_handlers, pa.schema(schema_fields) - - -def arrow_table_to_packets( - table: pa.Table, - registry: TypeRegistry, -) -> list[Packet]: - """Convert Arrow table to packet with field metadata. - - Args: - packet: Dictionary mapping parameter names to Python values - - Returns: - PyArrow Table with the packet data as a single row - """ - packets: list[Packet] = [] - - # prepare converter for each field - - def no_op(x) -> Any: - return x - - converter_lut = {} - for field in table.schema: - if field.metadata and b"semantic_type" in field.metadata: - semantic_type = field.metadata[b"semantic_type"].decode("utf-8") - if semantic_type: - handler = registry.get_handler_by_semantic_name(semantic_type) - if handler is None: - raise ValueError( - f"No handler registered for semantic type '{semantic_type}'" - ) - converter_lut[field.name] = handler.storage_to_python - - # Create packets from the Arrow table - # TODO: make this more efficient - for row in range(table.num_rows): - packet: Packet = {} - for field in table.schema: - value = table.column(field.name)[row].as_py() - packet[field.name] = converter_lut.get(field.name, no_op)(value) - packets.append(packet) - - return packets - - -def is_packet_supported( - python_type_info: TypeSpec, registry: TypeRegistry, type_lut: dict | None = None -) -> bool: - """Check if all types in the packet are supported by the registry or known to the default lut.""" - if type_lut is None: - type_lut = {} - return all( - python_type in registry or python_type in type_lut - for python_type in python_type_info.values() - ) - - -def create_arrow_table_with_meta( - storage_packet: dict[str, Any], type_info: dict[str, TypeInfo] -): - """Create an Arrow table with metadata from a storage packet. 
- - Args: - storage_packet: Dictionary with values in storage format - type_info: Dictionary mapping parameter names to TypeInfo objects - - Returns: - PyArrow Table with metadata - """ - schema_fields = [] - for key, type_info_obj in type_info.items(): - field_metadata = {} - if type_info_obj.semantic_type: - field_metadata["semantic_type"] = type_info_obj.semantic_type - - field = pa.field(key, type_info_obj.arrow_type, metadata=field_metadata) - schema_fields.append(field) - - schema = pa.schema(schema_fields) - - arrays = [] - for field in schema: - value = storage_packet[field.name] - array = pa.array([value], type=field.type) - arrays.append(array) - - return pa.Table.from_arrays(arrays, schema=schema) - - -def retrieve_storage_packet_from_arrow_with_meta( - arrow_table: pa.Table, -) -> dict[str, Any]: - """Retrieve storage packet from Arrow table with metadata. - - Args: - arrow_table: PyArrow Table with metadata - - Returns: - Dictionary representing the storage packet - """ - storage_packet = {} - for field in arrow_table.schema: - # Extract value from Arrow array - array = arrow_table.column(field.name) - if array.num_chunks > 0: - value = array.chunk(0).as_py()[0] # Get first value - else: - value = None # Handle empty arrays - - storage_packet[field.name] = value - - return storage_packet diff --git a/src/orcapod/types/schemas.py b/src/orcapod/types/schemas.py new file mode 100644 index 0000000..4f78ca5 --- /dev/null +++ b/src/orcapod/types/schemas.py @@ -0,0 +1,267 @@ + +from orcapod.types import TypeSpec +from orcapod.types.semantic_type_registry import SemanticTypeRegistry +from typing import Any +import pyarrow as pa +import datetime + +# This mapping is expected to be stable +# Be sure to test this assumption holds true +DEFAULT_ARROW_TYPE_LUT = { + int: pa.int64(), + float: pa.float64(), + str: pa.large_string(), + bool: pa.bool_(), +} + +def python_to_arrow_type(python_type: type) -> pa.DataType: + if python_type in DEFAULT_ARROW_TYPE_LUT: + return DEFAULT_ARROW_TYPE_LUT[python_type] + raise TypeError(f"Converstion of python type {python_type} is not supported yet") + +def arrow_to_python_type(arrow_type: pa.DataType) -> type: + if pa.types.is_integer(arrow_type): + return int + elif pa.types.is_floating(arrow_type): + return float + elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + return str + elif pa.types.is_boolean(arrow_type): + return bool + elif pa.types.is_date(arrow_type): + return datetime.date + elif pa.types.is_timestamp(arrow_type): + return datetime.datetime + elif pa.types.is_binary(arrow_type): + return bytes + else: + raise TypeError(f"Conversion of arrow type {arrow_type} is not supported") + + + +# class PythonSchema(dict[str, type]): +# """ +# A schema for Python data types, mapping string keys to Python types. + +# This is used to define the expected structure of data packets in OrcaPod. + +# Attributes +# ---------- +# keys : str +# The keys of the schema. +# values : type +# The types corresponding to each key. + +# Examples +# -------- +# >>> schema = PythonSchema(name=str, age=int) +# >>> print(schema) +# {'name': , 'age': } +# """ + +PythonSchema = TypeSpec + + +class SemanticSchema(dict[str, tuple[type, str|None]]): + """ + A schema for semantic types, mapping string keys to tuples of Python types and optional metadata. + + This is used to define the expected structure of data packets with semantic types in OrcaPod. + + Attributes + ---------- + keys : str + The keys of the schema. 
+ values : tuple[type, str|None] + The types and optional semantic type corresponding to each key. + + Examples + -------- + >>> schema = SemanticSchema(image=(str, 'path'), age=(int, None)) + >>> print(schema) + {'image': (, 'path'), 'age': (, None)} + """ + def get_store_type(self, key: str) -> type | None: + """ + Get the storage type for a given key in the schema. + + Parameters + ---------- + key : str + The key for which to retrieve the storage type. + + Returns + ------- + type | None + The storage type associated with the key, or None if not found. + """ + return self.get(key, (None, None))[0] + + def get_semantic_type(self, key: str) -> str | None: + """ + Get the semantic type for a given key in the schema. + + Parameters + ---------- + key : str + The key for which to retrieve the semantic type. + + Returns + ------- + str | None + The semantic type associated with the key, or None if not found. + """ + return self.get(key, (None, None))[1] + + +def from_python_schema_to_semantic_schema( + python_schema: PythonSchema, + semantic_type_registry: SemanticTypeRegistry, +) -> SemanticSchema: + """ + Convert a Python schema to a semantic schema using the provided semantic type registry. + + Parameters + ---------- + python_schema : PythonSchema + The schema to convert, mapping keys to Python types. + semantic_type_registry : SemanticTypeRegistry + The registry containing semantic type information. + + Returns + ------- + SemanticSchema + A new schema mapping keys to tuples of Python types and optional semantic type identifiers. + + Examples + -------- + >>> python_schema = PythonSchema(name=str, age=int) + >>> semantic_schema = from_python_schema_to_semantic_schema(python_schema, registry) + >>> print(semantic_schema) + {'name': (, None), 'age': (, None)} + """ + semantic_schema = {} + for key, python_type in python_schema.items(): + if python_type in semantic_type_registry: + type_info = semantic_type_registry.get_type_info(python_type) + assert type_info is not None, f"Type {python_type} should be found in the registry as `in` returned True" + semantic_schema[key] = (type_info.storage_type, type_info.semantic_type) + else: + semantic_schema[key] = (python_type, None) + return SemanticSchema(semantic_schema) + +def from_semantic_schema_to_python_schema( + semantic_schema: SemanticSchema, + semantic_type_registry: SemanticTypeRegistry, +) -> PythonSchema: + """ + Convert a semantic schema to a Python schema using the provided semantic type registry. + + Parameters + ---------- + semantic_schema : SemanticSchema + The schema to convert, mapping keys to tuples of Python types and optional semantic type identifiers. + semantic_type_registry : SemanticTypeRegistry + The registry containing semantic type information. + + Returns + ------- + PythonSchema + A new schema mapping keys to Python types. 
+ + Examples + -------- + >>> semantic_schema = SemanticSchema(name=(str, None), age=(int, None)) + >>> python_schema = from_semantic_schema_to_python_schema(semantic_schema, registry) + >>> print(python_schema) + {'name': , 'age': } + """ + python_schema = {} + for key, (python_type, semantic_type) in semantic_schema.items(): + if semantic_type is not None: + # If the semantic type is registered, use the corresponding Python type + python_type = semantic_type_registry.get_python_type(semantic_type) + python_schema[key] = python_type + return python_schema + +def from_semantic_schema_to_arrow_schema( + semantic_schema: SemanticSchema, + include_source_info: bool = True, +) -> pa.Schema: + """ + Convert a semantic schema to an Arrow schema. + + Parameters + ---------- + semantic_schema : SemanticSchema + The schema to convert, mapping keys to tuples of Python types and optional semantic type identifiers. + + Returns + ------- + dict[str, type] + A new schema mapping keys to Arrow-compatible types. + + Examples + -------- + >>> semantic_schema = SemanticSchema(name=(str, None), age=(int, None)) + >>> arrow_schema = from_semantic_schema_to_arrow_schema(semantic_schema) + >>> print(arrow_schema) + {'name': str, 'age': int} + """ + fields = [] + for field_name, (python_type, semantic_type) in semantic_schema.items(): + arrow_type = DEFAULT_ARROW_TYPE_LUT[python_type] + field_metadata = {b"semantic_type": semantic_type.encode('utf-8')} if semantic_type else {} + fields.append(pa.field(field_name, arrow_type, metadata=field_metadata)) + + if include_source_info: + for field in semantic_schema: + field_metadata = {b'field_type': b'source_info'} + fields.append(pa.field(f'_source_info_{field}', pa.large_string(), metadata=field_metadata)) + + return pa.schema(fields) + +def from_arrow_schema_to_semantic_schema( + arrow_schema: pa.Schema, + semantic_type_registry: SemanticTypeRegistry | None = None, +) -> SemanticSchema: + """ + Convert an Arrow schema to a semantic schema. + + Parameters + ---------- + arrow_schema : pa.Schema + The schema to convert, containing fields with metadata. + + Returns + ------- + SemanticSchema + A new schema mapping keys to tuples of Python types and optional semantic type identifiers. + + Examples + -------- + >>> arrow_schema = pa.schema([pa.field('name', pa.string(), metadata={'semantic_type': 'name'}), + ... 
pa.field('age', pa.int64(), metadata={'semantic_type': 'age'})]) + >>> semantic_schema = from_arrow_schema_to_semantic_schema(arrow_schema) + >>> print(semantic_schema) + {'name': (str, 'name'), 'age': (int, 'age')} + """ + semantic_schema = {} + for field in arrow_schema: + if field.metadata.get(b'field_type', b'') == b'source_info': + # Skip source info fields + continue + semantic_type = field.metadata.get(b'semantic_type', None) + semantic_type = semantic_type.decode() if semantic_type else None + if semantic_type: + if semantic_type_registry is None: + raise ValueError("Semantic type registry must be provided for semantic types") + python_type = semantic_type_registry.get_python_type(semantic_type) + if python_type is None: + raise ValueError(f"Semantic type '{semantic_type}' is not registered in the registry") + else: + python_type = arrow_to_python_type(field.type) + + semantic_schema[field.name] = (python_type, semantic_type) + return SemanticSchema(semantic_schema) + diff --git a/src/orcapod/types/handlers.py b/src/orcapod/types/semantic_type_handlers.py similarity index 92% rename from src/orcapod/types/handlers.py rename to src/orcapod/types/semantic_type_handlers.py index ecbdfba..a15f9d5 100644 --- a/src/orcapod/types/handlers.py +++ b/src/orcapod/types/semantic_type_handlers.py @@ -9,7 +9,7 @@ class PathHandler: """Handler for pathlib.Path objects, stored as strings.""" - def python_types(self) -> type: + def python_type(self) -> type: return Path def storage_type(self) -> pa.DataType: @@ -25,7 +25,7 @@ def storage_to_python(self, value: str) -> Path | None: class UUIDHandler: """Handler for UUID objects, stored as strings.""" - def python_types(self) -> type: + def python_type(self) -> type: return UUID def storage_type(self) -> pa.DataType: @@ -41,7 +41,7 @@ def storage_to_python(self, value: str) -> UUID | None: class DecimalHandler: """Handler for Decimal objects, stored as strings.""" - def python_types(self) -> type: + def python_type(self) -> type: return Decimal def storage_type(self) -> pa.DataType: @@ -61,7 +61,7 @@ def __init__(self, python_type: type, arrow_type: pa.DataType): self._python_type = python_type self._arrow_type = arrow_type - def python_types(self) -> type: + def python_type(self) -> type: return self._python_type def storage_type(self) -> pa.DataType: @@ -80,7 +80,7 @@ class DirectArrowHandler: def __init__(self, arrow_type: pa.DataType): self._arrow_type = arrow_type - def python_types(self) -> type: + def python_type(self) -> type: return self._arrow_type def storage_type(self) -> pa.DataType: @@ -96,7 +96,7 @@ def storage_to_python(self, value: Any) -> Any: class DateTimeHandler: """Handler for datetime objects.""" - def python_types(self) -> tuple[type, ...]: + def python_type(self) -> type: return (datetime, date, time) # Handles multiple related types def storage_type(self) -> pa.DataType: diff --git a/src/orcapod/types/semantic_type_registry.py b/src/orcapod/types/semantic_type_registry.py new file mode 100644 index 0000000..d954891 --- /dev/null +++ b/src/orcapod/types/semantic_type_registry.py @@ -0,0 +1,468 @@ +from collections.abc import Callable, Collection, Sequence, Mapping +import logging +from optparse import Values +from typing import Any, get_origin, get_args +from types import UnionType +import pyarrow as pa +from orcapod.types.packets import Packet, PacketLike +from .core import TypeHandler, TypeSpec +from dataclasses import dataclass + +# This mapping is expected to be stable +# Be sure to test this assumption holds true 
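# One way the note above could be honored (a sketch, not part of this patch):
# pin the expected Arrow type names so that a pyarrow upgrade which silently
# changes the mapping fails loudly. The function name and placement are
# hypothetical; a check like this would normally live in a test module.
def _check_default_arrow_type_lut_stability() -> None:
    expected_names = {int: "int64", float: "double", str: "string", bool: "bool"}
    for python_type, arrow_name in expected_names.items():
        actual = DEFAULT_ARROW_TYPE_LUT[python_type]
        assert str(actual) == arrow_name, (
            f"Arrow mapping for {python_type!r} changed to {actual}"
        )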
+DEFAULT_ARROW_TYPE_LUT = { + int: pa.int64(), + float: pa.float64(), + str: pa.string(), + bool: pa.bool_(), +} + +logger = logging.getLogger(__name__) + + +# TODO: reconsider the need for this dataclass as its information is superfluous +# to the registration of the handler into the registry. +@dataclass +class TypeInfo: + python_type: type + storage_type: type + semantic_type: str | None # name under which the type is registered + handler: "TypeHandler" + + +class SemanticTypeRegistry: + """Registry that manages type handlers with semantic type names.""" + + def __init__(self): + self._handlers: dict[ + type, tuple[TypeHandler, str] + ] = {} # PythonType -> (Handler, semantic_name) + self._semantic_handlers: dict[str, TypeHandler] = {} # semantic_name -> Handler + self._semantic_to_python_lut: dict[str, type] = {} # semantic_name -> Python type + + def register( + self, + semantic_type: str, + handler: TypeHandler, + ): + """Register a handler with a semantic type name. + + Args: + semantic_name: Identifier for this semantic type (e.g., 'path', 'uuid') + handler: The type handler instance + explicit_types: Optional override of types to register for (if different from handler's supported_types) + override: If True, allow overriding existing registration for the same semantic name and Python type(s) + """ + # Determine which types to register for + + python_type = handler.python_type() + + # Register handler for each type + if python_type in self._handlers: + existing_semantic = self._handlers[python_type][1] + # TODO: handle overlapping registration more gracefully + raise ValueError( + f"Type {python_type} already registered with semantic type '{existing_semantic}'" + ) + + # Register by semantic name + if semantic_type in self._semantic_handlers: + raise ValueError(f"Semantic type '{semantic_type}' already registered") + + self._handlers[python_type] = (handler, semantic_type) + self._semantic_handlers[semantic_type] = handler + self._semantic_to_python_lut[semantic_type] = python_type + + def get_python_type(self, semantic_type: str) -> type | None: + """Get Python type for a semantic type.""" + return self._semantic_to_python_lut.get(semantic_type) + + + + def get_semantic_type(self, python_type: type) -> str | None: + """Get semantic type for a Python type.""" + handler_info = self._handlers.get(python_type) + return handler_info[1] if handler_info else None + + def get_handler(self, python_type: type) -> TypeHandler | None: + """Get handler for a Python type.""" + handler_info = self._handlers.get(python_type) + return handler_info[0] if handler_info else None + + def get_handler_by_semantic_type(self, semantic_type: str) -> TypeHandler | None: + """Get handler by semantic type.""" + return self._semantic_handlers.get(semantic_type) + + + def get_type_info(self, python_type: type) -> TypeInfo | None: + """Get TypeInfo for a Python type.""" + handler = self.get_handler(python_type) + if handler is None: + return None + semantic_type = self.get_semantic_type(python_type) + return TypeInfo( + python_type=python_type, + storage_type=handler.storage_type(), + semantic_type=semantic_type, + handler=handler, + ) + + + def __contains__(self, python_type: type) -> bool: + """Check if a Python type is registered.""" + return python_type in self._handlers + + + + + + +# Below is a collection of functions that handles converting between various aspects of Python packets and Arrow tables. 
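# A minimal usage sketch for the registry above (hypothetical example; it
# assumes the PathHandler from semantic_type_handlers.py, which stores
# pathlib.Path values as strings). Building the registry inside a function
# keeps the sketch free of import-time side effects.
def _example_path_registry() -> SemanticTypeRegistry:
    from pathlib import Path

    from orcapod.types.semantic_type_handlers import PathHandler

    registry = SemanticTypeRegistry()
    registry.register("path", PathHandler())

    assert Path in registry
    info = registry.get_type_info(Path)
    assert info is not None and info.semantic_type == "path"
    assert registry.get_python_type("path") is Path
    return registry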
+# Here for convenience, any Python dictionary with str keys and supported Python values are referred to as a packet. + + +# Conversions are: +# python packet <-> storage packet <-> arrow table +# python typespec <-> storage typespec <-> arrow schema +# +# python packet <-> storage packet requires the use of SemanticTypeRegistry +# conversion between storage packet <-> arrow table requires info about semantic_type + + +# # Storage packet <-> Arrow table + +# def stroage_typespec_to_arrow_schema(storage_typespec:TypeSpec, semantic_type_info: dict[str, str]|None = None) -> pa.Schema: +# """Convert storage typespec to Arrow Schema with semantic_type metadata.""" +# """Convert storage typespec to PyArrow Schema with semantic_type metadata.""" +# if semantic_type_info is None: +# semantic_type_info = {} + +# fields = [] +# for field_name, field_type in storage_typespec.items(): +# arrow_type = python_to_pyarrow_type(field_type) +# semantic_type = semantic_type_info.get(field_name, None) +# field_metadata = {"semantic_type": semantic_type} if semantic_type else {} +# fields.append(pa.field(field_name, arrow_type, metadata=field_metadata)) +# return pa.schema(fields) + +# def arrow_schema_to_storage_typespec(schema: pa.Schema) -> tuple[TypeSpec, dict[str, str]|None]: +# """Convert Arrow Schema to storage typespec and semantic type metadata.""" +# typespec = {} +# semantic_type_info = {} + +# for field in schema: +# field_type = field.type +# typespec[field.name] = field_type.to_pandas_dtype() # Convert Arrow type to Pandas dtype +# if field.metadata and b"semantic_type" in field.metadata: +# semantic_type_info[field.name] = field.metadata[b"semantic_type"].decode("utf-8") + +# return typespec, semantic_type_info + + +# def storage_packet_to_arrow_table( +# storage_packet: PacketLike, +# typespec: TypeSpec | None = None, +# semantic_type_info: dict[str, str] | None = None, + + + +# # TypeSpec + TypeRegistry + ArrowLUT -> Arrow Schema (annotated with semantic_type) + +# # + + + + + + +# # TypeSpec <-> Arrow Schema + +# def schema_from_typespec(typespec: TypeSpec, registry: SemanticTypeRegistry, metadata_info: dict | None = None) -> pa.Schema: +# """Convert TypeSpec to PyArrow Schema.""" +# if metadata_info is None: +# metadata_info = {} + +# fields = [] +# for field_name, field_type in typespec.items(): +# type_info = registry.get_type_info(field_type) +# if type_info is None: +# raise ValueError(f"No type info registered for {field_type}") +# fields.append(pa.field(field_name, type_info.arrow_type, metadata={ +# "semantic_type": type_info.semantic_type +# })) +# return pa.schema(fields) + +# def create_schema_from_typespec( +# typespec: TypeSpec, +# registry: SemanticTypeRegistry, +# metadata_info: dict | None = None, +# arrow_type_lut: dict[type, pa.DataType] | None = None, +# ) -> tuple[list[tuple[str, TypeHandler]], pa.Schema]: +# if metadata_info is None: +# metadata_info = {} +# if arrow_type_lut is None: +# arrow_type_lut = DEFAULT_ARROW_TYPE_LUT + +# keys_with_handlers: list[tuple[str, TypeHandler]] = [] +# schema_fields = [] +# for key, python_type in typespec.items(): +# type_info = registry.get_type_info(python_type) + +# field_metadata = {} +# if type_info and type_info.semantic_type: +# field_metadata["semantic_type"] = type_info.semantic_type +# keys_with_handlers.append((key, type_info.handler)) +# arrow_type = type_info.arrow_type +# else: +# arrow_type = arrow_type_lut.get(python_type) +# if arrow_type is None: +# raise ValueError( +# f"Direct support for Python type {python_type} 
is not provided. Register a handler to work with {python_type}" +# ) + +# schema_fields.append(pa.field(key, arrow_type, metadata=field_metadata)) +# return keys_with_handlers, pa.schema(schema_fields) + + + +# def arrow_table_to_packets( +# table: pa.Table, +# registry: SemanticTypeRegistry, +# ) -> list[Packet]: +# """Convert Arrow table to packet with field metadata. + +# Args: +# packet: Dictionary mapping parameter names to Python values + +# Returns: +# PyArrow Table with the packet data as a single row +# """ +# packets: list[Packet] = [] + +# # prepare converter for each field + +# def no_op(x) -> Any: +# return x + +# converter_lut = {} +# for field in table.schema: +# if field.metadata and b"semantic_type" in field.metadata: +# semantic_type = field.metadata[b"semantic_type"].decode("utf-8") +# if semantic_type: +# handler = registry.get_handler_by_semantic_name(semantic_type) +# if handler is None: +# raise ValueError( +# f"No handler registered for semantic type '{semantic_type}'" +# ) +# converter_lut[field.name] = handler.storage_to_python + +# # Create packets from the Arrow table +# # TODO: make this more efficient +# for row in range(table.num_rows): +# packet: Packet = Packet() +# for field in table.schema: +# value = table.column(field.name)[row].as_py() +# packet[field.name] = converter_lut.get(field.name, no_op)(value) +# packets.append(packet) + +# return packets + + +# def create_arrow_table_with_meta( +# storage_packet: dict[str, Any], type_info: dict[str, TypeInfo] +# ): +# """Create an Arrow table with metadata from a storage packet. + +# Args: +# storage_packet: Dictionary with values in storage format +# type_info: Dictionary mapping parameter names to TypeInfo objects + +# Returns: +# PyArrow Table with metadata +# """ +# schema_fields = [] +# for key, type_info_obj in type_info.items(): +# field_metadata = {} +# if type_info_obj.semantic_type: +# field_metadata["semantic_type"] = type_info_obj.semantic_type + +# field = pa.field(key, type_info_obj.arrow_type, metadata=field_metadata) +# schema_fields.append(field) + +# schema = pa.schema(schema_fields) + +# arrays = [] +# for field in schema: +# value = storage_packet[field.name] +# array = pa.array([value], type=field.type) +# arrays.append(array) + +# return pa.Table.from_arrays(arrays, schema=schema) + + +# def retrieve_storage_packet_from_arrow_with_meta( +# arrow_table: pa.Table, +# ) -> dict[str, Any]: +# """Retrieve storage packet from Arrow table with metadata. 
+ +# Args: +# arrow_table: PyArrow Table with metadata + +# Returns: +# Dictionary representing the storage packet +# """ +# storage_packet = {} +# for field in arrow_table.schema: +# # Extract value from Arrow array +# array = arrow_table.column(field.name) +# if array.num_chunks > 0: +# value = array.chunk(0).as_py()[0] # Get first value +# else: +# value = None # Handle empty arrays + +# storage_packet[field.name] = value + +# return storage_packet + +# def typespec_to_schema_with_metadata(typespec: TypeSpec, field_metadata: dict|None = None) -> pa.Schema: +# """Convert TypeSpec to PyArrow Schema""" +# fields = [] +# for field_name, field_type in typespec.items(): +# arrow_type = python_to_pyarrow_type(field_type) +# fields.append(pa.field(field_name, arrow_type)) +# return pa.schema(fields) + +# def python_to_pyarrow_type(python_type: type, strict:bool=True) -> pa.DataType: +# """Convert Python type (including generics) to PyArrow type""" +# # For anywhere we need to store str value, we use large_string as is done in Polars + +# # Handle basic types first +# basic_mapping = { +# int: pa.int64(), +# float: pa.float64(), +# str: pa.large_string(), +# bool: pa.bool_(), +# bytes: pa.binary(), +# } + +# if python_type in basic_mapping: +# return basic_mapping[python_type] + +# # Handle generic types +# origin = get_origin(python_type) +# args = get_args(python_type) + +# if origin is list: +# # Handle list[T] +# if args: +# element_type = python_to_pyarrow_type(args[0]) +# return pa.list_(element_type) +# else: +# return pa.list_(pa.large_string()) # default to list of strings + +# elif origin is dict: +# # Handle dict[K, V] - PyArrow uses map type +# if len(args) == 2: +# key_type = python_to_pyarrow_type(args[0]) +# value_type = python_to_pyarrow_type(args[1]) +# return pa.map_(key_type, value_type) +# else: +# # Otherwise default to using long string +# return pa.map_(pa.large_string(), pa.large_string()) + +# elif origin is UnionType: +# # Handle Optional[T] (Union[T, None]) +# if len(args) == 2 and type(None) in args: +# non_none_type = args[0] if args[1] is type(None) else args[1] +# return python_to_pyarrow_type(non_none_type) + +# # Default fallback +# if not strict: +# logger.warning(f"Unsupported type {python_type}, defaulting to large_string") +# return pa.large_string() +# else: +# raise TypeError(f"Unsupported type {python_type} for PyArrow conversion. " +# "Set strict=False to allow fallback to large_string.") + +# def arrow_to_dicts(table: pa.Table) -> list[dict[str, Any]]: +# """ +# Convert Arrow table to dictionary or list of dictionaries. +# Returns a list of dictionaries (one per row) with column names as keys. +# Args: +# table: PyArrow Table to convert +# Returns: +# A list of dictionaries for multi-row tables. +# """ +# if len(table) == 0: +# return [] + +# # Multiple rows: return list of dicts (one per row) +# return [ +# {col_name: table.column(col_name)[i].as_py() for col_name in table.column_names} +# for i in range(len(table)) +# ] + +# def get_metadata_from_schema( +# schema: pa.Schema, metadata_field: bytes +# ) -> dict[str, str]: +# """ +# Extract metadata from Arrow schema fields. Metadata value will be utf-8 decoded. 
+# Args: +# schema: PyArrow Schema to extract metadata from +# metadata_field: Metadata field to extract (e.g., b'semantic_type') +# Returns: +# Dictionary mapping field names to their metadata values +# """ +# metadata = {} +# for field in schema: +# if field.metadata and metadata_field in field.metadata: +# metadata[field.name] = field.metadata[metadata_field].decode("utf-8") +# return metadata + +# def dict_to_arrow_table_with_metadata(data: dict, data_type_info: TypeSpec | None = None, metadata: dict | None = None): +# """ +# Convert a tag dictionary to PyArrow table with metadata on each column. + +# Args: +# tag: Dictionary with string keys and any Python data type values +# metadata_key: The metadata key to add to each column +# metadata_value: The metadata value to indicate this column came from tag +# """ +# if metadata is None: +# metadata = {} + +# if field_types is None: +# # First create the table to infer types +# temp_table = pa.Table.from_pylist([data]) + +# # Create new fields with metadata +# fields_with_metadata = [] +# for field in temp_table.schema: +# # Add metadata to each field +# field_metadata = metadata +# new_field = pa.field( +# field.name, field.type, nullable=field.nullable, metadata=field_metadata +# ) +# fields_with_metadata.append(new_field) + +# # Create schema with metadata +# schema_with_metadata = pa.schema(fields_with_metadata) + +# # Create the final table with the metadata-enriched schema +# table = pa.Table.from_pylist([tag], schema=schema_with_metadata) + +# return table + + +# # def get_columns_with_metadata( +# # df: pl.DataFrame, key: str, value: str | None = None +# # ) -> list[str]: +# # """Get column names with specific metadata using list comprehension. If value is given, only +# # columns matching that specific value for the desginated metadata key will be returned. +# # Otherwise, all columns that contains the key as metadata will be returned regardless of the value""" +# # return [ +# # col_name +# # for col_name, dtype in df.schema.items() +# # if hasattr(dtype, "metadata") +# # and (value is None or getattr(dtype, "metadata") == value) +# # ] diff --git a/src/orcapod/types/typespec.py b/src/orcapod/types/typespec_utils.py similarity index 83% rename from src/orcapod/types/typespec.py rename to src/orcapod/types/typespec_utils.py index eb5be89..0786d10 100644 --- a/src/orcapod/types/typespec.py +++ b/src/orcapod/types/typespec_utils.py @@ -1,8 +1,7 @@ # Library of functions for working with TypeSpecs and for extracting TypeSpecs from a function's signature - -from collections.abc import Callable, Collection, Sequence -from typing import get_origin, get_args +from collections.abc import Callable, Collection, Sequence, Mapping +from typing import get_origin, get_args, Any from .core import TypeSpec import inspect import logging @@ -213,3 +212,57 @@ def extract_function_typespecs( f"Type for return item '{key}' is not specified in output_types and has no type annotation in function signature." ) return param_info, inferred_output_types + + + +def get_typespec_from_dict(dict: Mapping) -> TypeSpec: + """ + Returns a TypeSpec for the given dictionary. + The TypeSpec is a mapping from field name to Python type. 
+ """ + return {key: type(value) for key, value in dict.items()} + + +def get_compatible_type(type1: Any, type2: Any) -> Any: + if type1 is type2: + return type1 + if issubclass(type1, type2): + return type2 + if issubclass(type2, type1): + return type1 + raise TypeError(f"Types {type1} and {type2} are not compatible") + + +def union_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: + if left is None: + return right + if right is None: + return left + # Merge the two TypeSpecs but raise an error if conflicts in types are found + merged = dict(left) + for key, right_type in right.items(): + merged[key] = ( + get_compatible_type(merged[key], right_type) + if key in merged + else right_type + ) + return merged + +def intersection_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: + """ + Returns the intersection of two TypeSpecs, only returning keys that are present in both. + If a key is present in both TypeSpecs, the type must be the same. + """ + if left is None or right is None: + return None + # Find common keys and ensure types match + common_keys = set(left.keys()).intersection(set(right.keys())) + intersection = {} + for key in common_keys: + try: + intersection[key] = get_compatible_type(left[key], right[key]) + except TypeError: + # If types are not compatible, raise an error + raise TypeError(f"Type conflict for key '{key}': {left[key]} vs {right[key]}") + + return intersection \ No newline at end of file diff --git a/src/orcapod/utils/stream_utils.py b/src/orcapod/utils/stream_utils.py index 95703c8..5c5bb62 100644 --- a/src/orcapod/utils/stream_utils.py +++ b/src/orcapod/utils/stream_utils.py @@ -12,23 +12,6 @@ V = TypeVar("V") -def get_typespec(dict: Mapping) -> TypeSpec: - """ - Returns a TypeSpec for the given dictionary. - The TypeSpec is a mapping from field name to Python type. - """ - return {key: type(value) for key, value in dict.items()} - - -def get_compatible_type(type1: Any, type2: Any) -> Any: - if type1 is type2: - return type1 - if issubclass(type1, type2): - return type2 - if issubclass(type2, type1): - return type1 - raise TypeError(f"Types {type1} and {type2} are not compatible") - def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: merged = left.copy() @@ -43,39 +26,6 @@ def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: return merged -def union_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: - if left is None: - return right - if right is None: - return left - # Merge the two TypeSpecs but raise an error if conflicts in types are found - merged = dict(left) - for key, right_type in right.items(): - merged[key] = ( - get_compatible_type(merged[key], right_type) - if key in merged - else right_type - ) - return merged - -def intersection_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: - """ - Returns the intersection of two TypeSpecs, only returning keys that are present in both. - If a key is present in both TypeSpecs, the type must be the same. 
- """ - if left is None or right is None: - return None - # Find common keys and ensure types match - common_keys = set(left.keys()).intersection(set(right.keys())) - intersection = {} - for key in common_keys: - try: - intersection[key] = get_compatible_type(left[key], right[key]) - except TypeError: - # If types are not compatible, raise an error - raise TypeError(f"Type conflict for key '{key}': {left[key]} vs {right[key]}") - - return intersection def common_elements(*values) -> Collection[str]: diff --git a/tests/test_hashing/test_composite_hasher.py b/tests/test_hashing/test_composite_hasher.py deleted file mode 100644 index f92cfea..0000000 --- a/tests/test_hashing/test_composite_hasher.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python -"""Tests for the CompositeFileHasher implementation.""" - -from unittest.mock import patch - -import pytest - -from orcapod.hashing.legacy_core import hash_to_hex -from orcapod.hashing.file_hashers import BasicFileHasher, DefaultCompositeFileHasher -from orcapod.hashing.types import FileHasher, PacketHasher, PathSetHasher - - -# Custom implementation of hash_file for tests that doesn't check for file existence -def mock_hash_file(file_path, algorithm="sha256", buffer_size=65536) -> str: - """Mock implementation of hash_file that doesn't check for file existence.""" - # Simply return a deterministic hash based on the file path - return hash_to_hex(f"mock_file_hash_{file_path}_{algorithm}") - - -# Custom implementation of hash_pathset for tests that doesn't check for file existence -def mock_hash_pathset( - pathset, algorithm="sha256", buffer_size=65536, char_count=32, file_hasher=None -): - """Mock implementation of hash_pathset that doesn't check for file existence.""" - from collections.abc import Collection - from os import PathLike - from pathlib import Path - - # If file_hasher is None, we'll need to handle it differently - if file_hasher is None: - # Just return a mock hash for testing - if isinstance(pathset, (str, Path, PathLike)): - return f"mock_{pathset}" - return "mock_hash" - - # Handle dictionary case for nested paths - if isinstance(pathset, dict): - hash_dict = {} - for key, value in pathset.items(): - hash_dict[key] = mock_hash_pathset( - value, algorithm, buffer_size, char_count, file_hasher - ) - return hash_to_hex(str(hash_dict)) - - # Handle collection case (list, set, etc.) 
- if isinstance(pathset, Collection) and not isinstance( - pathset, (str, Path, PathLike) - ): - hash_list = [] - for item in pathset: - hash_list.append( - mock_hash_pathset(item, algorithm, buffer_size, char_count, file_hasher) - ) - return hash_to_hex(str(hash_list)) - - # Handle simple string or Path case - if isinstance(pathset, (str, Path, PathLike)): - if hasattr(file_hasher, "__self__"): # For bound methods - return file_hasher(str(pathset)) - else: - return file_hasher(str(pathset)) - - return "mock_hash" - - -# Custom implementation of hash_packet for tests that doesn't check for file existence -def mock_hash_packet( - packet, - algorithm="sha256", - buffer_size=65536, - char_count=32, - prefix_algorithm=True, - pathset_hasher=None, -): - """Mock implementation of hash_packet that doesn't check for file existence.""" - # Create a simple hash based on the packet structure - hash_value = hash_to_hex(str(packet)) - - # Format it like the real function would - if prefix_algorithm and algorithm: - return ( - f"{algorithm}-{hash_value[: char_count if char_count else len(hash_value)]}" - ) - else: - return hash_value[: char_count if char_count else len(hash_value)] - - -@pytest.fixture(autouse=True) -def patch_hash_functions(): - """Patch the hash functions in the core module for all tests.""" - with ( - patch("orcapod.hashing.core.hash_file", side_effect=mock_hash_file), - patch("orcapod.hashing.core.hash_pathset", side_effect=mock_hash_pathset), - patch("orcapod.hashing.core.hash_packet", side_effect=mock_hash_packet), - ): - yield - - -def test_default_composite_hasher_implements_all_protocols(): - """Test that CompositeFileHasher implements all three protocols.""" - # Create a basic file hasher to be used within the composite hasher - file_hasher = BasicFileHasher() - - # Create the composite hasher - composite_hasher = DefaultCompositeFileHasher(file_hasher) - - # Verify it implements all three protocols - assert isinstance(composite_hasher, FileHasher) - assert isinstance(composite_hasher, PathSetHasher) - assert isinstance(composite_hasher, PacketHasher) - - -def test_default_composite_hasher_file_hashing(): - """Test CompositeFileHasher's file hashing functionality.""" - # We can use a mock path since our mocks don't require real files - file_path = "/path/to/mock_file.txt" - - # Create a custom mock file hasher - class MockFileHasher: - def hash_file(self, file_path): - return mock_hash_file(file_path) - - file_hasher = MockFileHasher() - composite_hasher = DefaultCompositeFileHasher(file_hasher) - - # Get hash from the composite hasher and directly from the file hasher - direct_hash = file_hasher.hash_file(file_path) - composite_hash = composite_hasher.hash_file(file_path) - - # The hashes should be identical - assert direct_hash == composite_hash - - -def test_default_composite_hasher_pathset_hashing(): - """Test CompositeFileHasher's path set hashing functionality.""" - - # Create a custom mock file hasher that doesn't check for file existence - class MockFileHasher: - def hash_file(self, file_path): - return mock_hash_file(file_path) - - file_hasher = MockFileHasher() - composite_hasher = DefaultCompositeFileHasher(file_hasher) - - # Simple path set with non-existent paths - pathset = ["/path/to/file1.txt", "/path/to/file2.txt"] - - # Hash the pathset - result = composite_hasher.hash_pathset(pathset) - - # The result should be a string hash - assert isinstance(result, str) - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git 
a/tests/test_store/test_transfer_data_store.py b/tests/test_store/test_transfer_data_store.py index 6fd2add..191da89 100644 --- a/tests/test_store/test_transfer_data_store.py +++ b/tests/test_store/test_transfer_data_store.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for TransferDataStore.""" -import json from pathlib import Path import pytest diff --git a/tests/test_types/test_inference/test_extract_function_data_types.py b/tests/test_types/test_inference/test_extract_function_data_types.py index e96fd9c..8ae1ea5 100644 --- a/tests/test_types/test_inference/test_extract_function_data_types.py +++ b/tests/test_types/test_inference/test_extract_function_data_types.py @@ -11,7 +11,7 @@ import pytest from collections.abc import Collection -from orcapod.types.typespec import extract_function_typespecs +from orcapod.types.typespec_utils import extract_function_typespecs class TestExtractFunctionDataTypes: From a3ba1723d40c0cb8b16d95567ba16eacaf6b2a1f Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 1 Jul 2025 00:52:46 +0000 Subject: [PATCH 26/57] feat: add field source tracking --- src/orcapod/core/base.py | 2 +- src/orcapod/core/pod.py | 2 +- src/orcapod/pipeline/wrappers.py | 200 ++++++++++++-------- src/orcapod/types/__init__.py | 4 +- src/orcapod/types/packets.py | 42 ++-- src/orcapod/types/schemas.py | 139 ++++++++++---- src/orcapod/types/semantic_type_handlers.py | 44 ++--- src/orcapod/types/semantic_type_registry.py | 23 ++- 8 files changed, 278 insertions(+), 178 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 9a30873..7c9a299 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -10,7 +10,7 @@ from orcapod.hashing import ContentIdentifiableBase from orcapod.types import Packet, Tag, TypeSpec -from orcapod.types.typespec import get_typespec_from_dict +from orcapod.types.typespec_utils import get_typespec_from_dict import logging diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index eb880b4..4271887 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -9,7 +9,7 @@ from orcapod.types import Packet, Tag, TypeSpec, default_registry from orcapod.types.typespec_utils import extract_function_typespecs -from orcapod.types.semantic_type_registry import PacketConverter +from orcapod.types.packets import PacketConverter from orcapod.hashing import ( FunctionInfoExtractor, diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index 4396223..c12f40a 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -1,15 +1,14 @@ from orcapod.core.pod import Pod, FunctionPod from orcapod.core import SyncStream, Source, Kernel from orcapod.store import ArrowDataStore -from orcapod.types import Tag, Packet, TypeSpec, default_registry +from orcapod.types import Tag, Packet, PacketLike, TypeSpec, default_registry from orcapod.types.typespec_utils import get_typespec_from_dict, union_typespecs, extract_function_typespecs -from orcapod.types.semantic_type_registry import create_arrow_table_with_meta +from orcapod.types.semantic_type_registry import SemanticTypeRegistry +from orcapod.types import packets, schemas from orcapod.hashing import ObjectHasher, ArrowHasher from orcapod.hashing.defaults import get_default_object_hasher, get_default_arrow_hasher from typing import Any, Literal from collections.abc import Collection, Iterator -from orcapod.types.semantic_type_registry import TypeRegistry -from orcapod.types.packet_converter import PacketConverter 
import pyarrow as pa import polars as pl from orcapod.core.streams import SyncStreamFromGenerator @@ -18,12 +17,15 @@ logger = logging.getLogger(__name__) +def get_tag_typespec(tag: Tag) -> dict[str, type]: + return {k: str for k in tag} class PolarsSource(Source): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str] | None = None): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str], packet_keys: Collection[str]|None = None): self.df = df self.tag_keys = tag_keys + self.packet_keys = packet_keys def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if len(streams) != 0: @@ -31,19 +33,25 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: "PolarsSource does not support forwarding streams. " "It generates its own stream from the DataFrame." ) - return PolarsStream(self.df, self.tag_keys) + return PolarsStream(self.df, self.tag_keys, self.packet_keys) class PolarsStream(SyncStream): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str]): + def __init__(self, df: pl.DataFrame, tag_keys: Collection[str], packet_keys: Collection[str] | None = None): self.df = df - self.tag_keys = tag_keys + self.tag_keys = tuple(tag_keys) + self.packet_keys = tuple(packet_keys) if packet_keys is not None else None def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - for row in self.df.iter_rows(named=True): + df = self.df + # if self.packet_keys is not None: + # df = df.select(self.tag_keys + self.packet_keys) + for row in df.iter_rows(named=True): tag = {key: row[key] for key in self.tag_keys} - packet = {key: val for key, val in row.items() if key not in self.tag_keys} - yield tag, Packet(packet) + packet = {key: val for key, val in row.items() if key not in self.tag_keys and not key.startswith("_source_info_")} + # TODO: revisit and fix this rather hacky implementation + source_info = {key.removeprefix("_source_info_"):val for key, val in row.items() if key.startswith("_source_info_")} + yield tag, Packet(packet, source_info=source_info) class EmptyStream(SyncStream): @@ -134,6 +142,13 @@ def claims_unique_tags( *resolved_streams, trigger_run=trigger_run ) + + + def post_call(self, tag: Tag, packet: Packet) -> None: ... + + def output_iterator_completion_hook(self) -> None: ... 
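# post_call() and output_iterator_completion_hook() above are deliberate no-ops:
# they give subclasses a per-packet hook and an end-of-stream hook without
# overriding forward() itself. CachedKernelWrapper.forward() below drives both
# from its caching generator. A hypothetical refinement (sketch, not part of
# this patch) could layer logging on top of the cached behaviour:
#
#     class LoggingCachedKernelWrapper(CachedKernelWrapper):
#         def post_call(self, tag: Tag, packet: Packet) -> None:
#             super().post_call(tag, packet)  # keep the caching behaviour
#             logger.debug("cached packet for tag %s", tag)
#
#         def output_iterator_completion_hook(self) -> None:
#             super().output_iterator_completion_hook()
#             logger.debug("stream exhausted; cache marked complete")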
+ + class CachedKernelWrapper(KernelInvocationWrapper, Source): """ @@ -154,7 +169,7 @@ def __init__( output_store: ArrowDataStore, kernel_hasher: ObjectHasher | None = None, arrow_packet_hasher: ArrowHasher | None = None, - packet_type_registry: TypeRegistry | None = None, + packet_type_registry: SemanticTypeRegistry | None = None, **kwargs, ) -> None: super().__init__(kernel, input_streams, **kwargs) @@ -172,9 +187,7 @@ def __init__( packet_type_registry = default_registry self._packet_type_registry = packet_type_registry - self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) - self.tag_keys, self.packet_keys = self.keys(trigger_run=False) - self.output_converter = None + self.update_cached_values() self._cache_computed = False @@ -203,70 +216,75 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): def update_cached_values(self): self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) - self.output_converter = None + self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) + if self.tag_typespec is None or self.packet_typespec is None: + raise ValueError("Currently, cached kernel wrapper can only work with kernels that have typespecs defined.") + # TODO: clean up and make it unnecessary to convert packet typespec + packet_schema = schemas.PythonSchema(self.packet_typespec) + joined_typespec = union_typespecs(self.tag_typespec, packet_schema.with_source_info) + if joined_typespec is None: + raise ValueError( + "Joined typespec should not be None. " + "This may happen if the tag typespec and packet typespec are incompatible." + ) + # Add any additional fields to the output converter here + self.output_converter = packets.PacketConverter(joined_typespec, registry=self.registry, include_source_info=False) + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: logger.info(f"Returning cached outputs for {self}") if self.df is not None: - return PolarsStream(self.df, tag_keys=self.tag_keys) + if self.tag_keys is None: + raise ValueError( + "CachedKernelWrapper has no tag keys defined, cannot return PolarsStream" + ) + source_info_sig = ':'.join(self.source_info) + return PolarsStream(self.df, tag_keys=self.tag_keys, packet_keys=self.packet_keys) else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.packet_keys) resolved_streams = self.resolve_input_streams(*streams) - output_stream = self.kernel.forward(*resolved_streams, **kwargs) - tag_typespec, packet_typespec = output_stream.types(trigger_run=False) - if tag_typespec is not None and packet_typespec is not None: - joined_type = union_typespecs(tag_typespec, packet_typespec) - assert joined_type is not None, "Joined typespec should not be None" - all_type = dict(joined_type) - for k in packet_typespec: - all_type[f'_source_{k}'] = str - # - self.output_converter = PacketConverter(all_type, registry=self.registry) - # Cache the output stream of the underlying kernel # If an entry with same tag and packet already exists in the output store, # it will not be added again, thus avoiding duplicates. 
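        # What post_call() below ultimately serializes is the "composite" view of
        # a packet: packet.get_composite() merges the packet fields with their
        # provenance under a "_source_info_" prefix. A small sketch of the shape
        # involved (the values are hypothetical):
        #
        #     packet = Packet({"image": "/data/s01.png"},
        #                     source_info={"image": "ingest:image"})
        #     packet.get_composite()
        #     # -> {"image": "/data/s01.png", "_source_info_image": "ingest:image"}
        #
        # PolarsStream above reverses the same convention when it rebuilds
        # packets and their source_info from cached rows.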
def generator() -> Iterator[tuple[Tag, Packet]]: logger.info(f"Computing and caching outputs for {self}") for tag, packet in output_stream: - merged_info = {**tag, **packet} - # add entries for source_info - for k, v in packet.source_info.items(): - merged_info[f'_source_{k}'] = v - - if self.output_converter is None: - # TODO: cleanup logic here - joined_type = get_typespec_from_dict(merged_info) - assert joined_type is not None, "Joined typespec should not be None" - all_type = dict(joined_type) - for k in packet: - all_type[f'_source_{k}'] = str - self.output_converter = PacketConverter( - all_type, registry=self.registry - ) - - # add entries for source_info - for k, v in packet.source_info.items(): - merged_info[f'_source_{k}'] = v - - output_table = self.output_converter.to_arrow_table(merged_info) - # TODO: revisit this logic - output_id = self.arrow_hasher.hash_table(output_table) - if not self.output_store.get_record(*self.source_info, output_id): - self.output_store.add_record( - *self.source_info, - output_id, - output_table, - ) + self.post_call(tag, packet) yield tag, packet - self._cache_computed = True + self.output_iterator_completion_hook() + + logger.info(f"Results cached for {self}") + self._cache_computed = True return SyncStreamFromGenerator(generator) + def post_call(self, tag: Tag, packet: Packet) -> None: + # Cache the output stream of the underlying kernel + # If an entry with same tag and packet already exists in the output store, + # it will not be added again, thus avoiding duplicates. + merged_info = {**tag, **packet.get_composite()} + output_table = self.output_converter.from_python_packet_to_arrow_table(merged_info) + # TODO: revisit this logic + output_id = self.arrow_hasher.hash_table(output_table) + if not self.output_store.get_record(*self.source_info, output_id): + self.output_store.add_record( + *self.source_info, + output_id, + output_table, + ) + + def output_iterator_completion_hook(self) -> None: + """ + Hook to be called when the generator is completed. + """ + logger.info(f"Results cached for {self}") + self._cache_computed = True + + @property def lazy_df(self) -> pl.LazyFrame | None: return self.output_store.get_all_records_as_polars(*self.source_info) @@ -333,7 +351,7 @@ def __init__( error_handling: Literal["raise", "ignore", "warn"] = "raise", object_hasher: ObjectHasher | None = None, arrow_hasher: ArrowHasher | None = None, - registry: TypeRegistry | None = None, + registry: SemanticTypeRegistry | None = None, **kwargs, ) -> None: super().__init__( @@ -391,11 +409,11 @@ def arrow_hasher(self, arrow_hasher: ArrowHasher | None = None): self.update_cached_values() @property - def registry(self) -> TypeRegistry: + def registry(self) -> SemanticTypeRegistry: return self._registry @registry.setter - def registry(self, registry: TypeRegistry | None = None): + def registry(self, registry: SemanticTypeRegistry | None = None): if registry is None: registry = default_registry self._registry = registry @@ -405,11 +423,29 @@ def registry(self, registry: TypeRegistry | None = None): def update_cached_values(self) -> None: self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) self.tag_keys, self.output_keys = self.keys(trigger_run=False) + if self.tag_keys is None or self.output_keys is None: + raise ValueError( + "Currently, cached function pod wrapper can only work with function pods that have keys defined." 
+ ) + self.all_keys = tuple(self.tag_keys) + tuple(self.output_keys) + self.tag_typespec, self.output_typespec = self.types(trigger_run=False) + if self.tag_typespec is None or self.output_typespec is None: + raise ValueError( + "Currently, cached function pod wrapper can only work with function pods that have typespecs defined." + ) self.input_typespec, self.output_typespec = ( self.function_pod.get_function_typespecs() ) - self.input_converter = PacketConverter(self.input_typespec, self.registry) - self.output_converter = PacketConverter(self.output_typespec, self.registry) + + self.input_converter = packets.PacketConverter(self.input_typespec, self.registry, include_source_info=False) + self.output_converter = packets.PacketConverter(self.output_typespec, self.registry, include_source_info=True) + + input_packet_source_typespec = {f'_source_info_{k}': str for k in self.input_typespec} + + # prepare typespec for tag record: __packet_key, tag, input packet source_info, + tag_record_typespec = {"__packet_key": str, **self.tag_typespec, **input_packet_source_typespec} + self.tag_record_converter = packets.PacketConverter(tag_record_typespec, self.registry, include_source_info=False) + def reset_cache(self): self._cache_computed = False @@ -425,14 +461,17 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: logger.info(f"Returning cached outputs for {self}") if self.df is not None: - return PolarsStream(self.df, self.tag_keys) + if self.tag_keys is None: + raise ValueError("Tag keys are not set, cannot return PolarsStream") + + return PolarsStream(self.df, self.tag_keys, packet_keys=self.output_keys) else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.output_keys) logger.info(f"Computing and caching outputs for {self}") return super().forward(*streams, **kwargs) def get_packet_key(self, packet: Packet) -> str: - return self.arrow_hasher.hash_table(self.input_converter.to_arrow_table(packet)) + return self.arrow_hasher.hash_table(self.input_converter.from_python_packet_to_arrow_table(packet)) @property def source_info(self): @@ -455,11 +494,9 @@ def _add_pipeline_record_with_packet_key(self, tag: Tag, packet_key: str, packet combined_info = dict(tag) # ensure we don't modify the original tag combined_info["__packet_key"] = packet_key for k, v in packet_source_info.items(): - combined_info[f'__{k}_source'] = v + combined_info[f'_source_info_{k}'] = v - # TODO: consider making this more efficient - # convert tag to arrow table - columns are labeled with metadata source=tag - table = create_arrow_table_with_meta(combined_info, {"source": "tag"}) + table = self.tag_record_converter.from_python_packet_to_arrow_table(combined_info) entry_hash = self.arrow_hasher.hash_table(table) @@ -492,7 +529,7 @@ def _retrieve_memoized_with_packet_key(self, packet_key: str) -> Packet | None: ) if arrow_table is None: return None - packets = self.function_pod.output_converter.from_arrow_table(arrow_table) + packets = self.output_converter.from_arrow_table_to_python_packets(arrow_table) # since memoizing single packet, it should only contain one packet assert len(packets) == 1, ( f"Memoizing single packet return {len(packets)} packets!" @@ -509,10 +546,10 @@ def memoize( Returns the memoized packet. 
""" logger.debug("Memoizing packet") - return self._memoize_with_packet_key(self.get_packet_key(packet), output_packet) + return self._memoize_with_packet_key(self.get_packet_key(packet), output_packet.get_composite()) def _memoize_with_packet_key( - self, packet_key: str, output_packet: Packet + self, packet_key: str, output_packet: PacketLike ) -> Packet: """ Memoize the output packet in the data store, looking up by packet key. @@ -521,11 +558,11 @@ def _memoize_with_packet_key( logger.debug(f"Memoizing packet with key {packet_key}") # TODO: this logic goes through the entire store and retrieve cycle with two conversions # consider simpler alternative - packets = self.output_converter.from_arrow_table( + packets = self.output_converter.from_arrow_table_to_python_packets( self.output_store.add_record( *self.source_info, packet_key, - self.output_converter.to_arrow_table(output_packet), + self.output_converter.from_python_packet_to_arrow_table(output_packet), ) ) # since passed in a single packet, it should only return a single packet @@ -535,9 +572,7 @@ def _memoize_with_packet_key( packet = packets[0] # TODO: reconsider the right place to attach this information # attach provenance information - packet_source_id = ":".join(self.source_info + (packet_key,)) - source_info = {k: f'{packet_source_id}:{k}' for k in packet} - return Packet(packet, source_info=source_info) + return Packet(packet) def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: @@ -567,6 +602,10 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: if output_packet is not None and not self.skip_memoization: # output packet may be modified by the memoization process # e.g. if the output is a file, the path may be changed + # add source info to the output packet + source_info = {k: '-'.join(self.source_info) + "-" + packet_key for k in output_packet.source_info} + # TODO: fix and make this not access protected field directly + output_packet.source_info = source_info output_packet = self._memoize_with_packet_key(packet_key, output_packet) # type: ignore if output_packet is None: @@ -593,7 +632,7 @@ def get_all_tags(self, with_packet_id: bool = False) -> pl.LazyFrame | None: return data.drop("__packet_key") if data is not None else None return data - def get_all_entries_with_tags(self) -> pl.LazyFrame | None: + def get_all_entries_with_tags(self, keep_hidden_fields: bool = False) -> pl.LazyFrame | None: """ Retrieve all entries from the tag store with their associated tags. Returns a DataFrame with columns for tag and packet key. @@ -612,9 +651,12 @@ def get_all_entries_with_tags(self) -> pl.LazyFrame | None: if result_packets is None: return None - return pl.concat([tag_records, result_packets], how="horizontal").drop( + pl_df = pl.concat([tag_records, result_packets], how="horizontal").drop( ["__packet_key"] ) + if not keep_hidden_fields: + pl_df = pl_df.select(self.all_keys) + return pl_df @property def df(self) -> pl.DataFrame | None: diff --git a/src/orcapod/types/__init__.py b/src/orcapod/types/__init__.py index a4615f5..03a3b4b 100644 --- a/src/orcapod/types/__init__.py +++ b/src/orcapod/types/__init__.py @@ -1,5 +1,5 @@ from .core import Tag, PathLike, PathSet, PodFunction, TypeSpec -from .packets import Packet +from .packets import Packet, PacketLike from .semantic_type_registry import SemanticTypeRegistry from .semantic_type_handlers import PathHandler, UUIDHandler, DateTimeHandler from . 
import semantic_type_handlers @@ -20,7 +20,7 @@ "default_registry", "Tag", "Packet", - "PacketLike" + "PacketLike", "TypeSpec", "PathLike", "PathSet", diff --git a/src/orcapod/types/packets.py b/src/orcapod/types/packets.py index a8d8775..4a3b192 100644 --- a/src/orcapod/types/packets.py +++ b/src/orcapod/types/packets.py @@ -36,6 +36,16 @@ def typespec(self) -> TypeSpec: @property def source_info(self) -> dict[str, str | None]: return {key: self._source_info.get(key, None) for key in self.keys()} + + @source_info.setter + def source_info(self, source_info: Mapping[str, str | None]): + self._source_info = {key: value for key, value in source_info.items() if value is not None} + + def get_composite(self) -> PacketLike: + composite = self.copy() + for k, v in self.source_info.items(): + composite[f"_source_info_{k}"] = v + return composite @@ -69,15 +79,20 @@ def __init__(self, *args, semantic_schema: schemas.SemanticSchema | None = None, source_info = {} self.source_info = source_info + def get_composite(self) -> dict[str, Any]: + composite = self.copy() + for k, v in self.source_info.items(): + composite[f"_source_info_{k}"] = v + return composite class PacketConverter: - def __init__(self, python_schema: schemas.PythonSchema, registry: SemanticTypeRegistry, include_source_info: bool = True): - self.python_schema = python_schema + def __init__(self, typespec: TypeSpec, registry: SemanticTypeRegistry, include_source_info: bool = True): + self.typespec = typespec self.registry = registry - self.semantic_schema = schemas.from_python_schema_to_semantic_schema( - python_schema, registry + self.semantic_schema = schemas.from_typespec_to_semantic_schema( + typespec, registry ) self.include_source_info = include_source_info @@ -90,7 +105,7 @@ def __init__(self, python_schema: schemas.PythonSchema, registry: SemanticTypeRe self.key_handlers: dict[str, TypeHandler] = {} - self.expected_key_set = set(self.python_schema.keys()) + self.expected_key_set = set(self.typespec.keys()) for key, (_, semantic_type) in self.semantic_schema.items(): if semantic_type is None: @@ -168,16 +183,11 @@ def from_semantic_packet_to_arrow_table(self, semantic_packet: SemanticPacket) - Returns: Arrow table representation of the packet """ - arrays = [] - for field in self.arrow_schema: - value = semantic_packet.get(field.name, None) - arrays.append(pa.array([value], type=field.type)) - if self.include_source_info: - for field, value in semantic_packet.source_info.items(): - arrays.append(pa.array([value], type=pa.large_string())) + return pa.Table.from_pylist([semantic_packet.get_composite()], schema=self.arrow_schema) + else: + return pa.Table.from_pylist([semantic_packet], schema=self.arrow_schema) - return pa.Table.from_arrays(arrays, schema=self.arrow_schema) def from_arrow_table_to_semantic_packets(self, arrow_table: pa.Table) -> Collection[SemanticPacket]: """Convert an Arrow table to a semantic packet. 
@@ -198,8 +208,8 @@ def from_arrow_table_to_semantic_packets(self, arrow_table: pa.Table) -> Collect semantic_packets = [] for all_packet_content in semantic_packets_contents: packet_content = {k: v for k, v in all_packet_content.items() if k in self.expected_key_set} - source_info = {k.strip('_source_info_'): v for k, v in all_packet_content.items() if k.startswith('_source_info_')} - semantic_packets.append(SemanticPacket(packet_content, _semantic_schema=self.semantic_schema, _source_info=source_info)) + source_info = {k.removeprefix('_source_info_'): v for k, v in all_packet_content.items() if k.startswith('_source_info_')} + semantic_packets.append(SemanticPacket(packet_content, semantic_schema=self.semantic_schema, source_info=source_info)) return semantic_packets @@ -213,7 +223,7 @@ def from_semantic_packet_to_python_packet(self, semantic_packet: SemanticPacket) Python packet representation of the semantic packet """ # Validate packet keys - python_packet = Packet(semantic_packet, typespec=self.python_schema, source_info=semantic_packet.source_info) + python_packet = Packet(semantic_packet, typespec=self.typespec, source_info=semantic_packet.source_info) packet_keys = set(python_packet.keys()) self._check_key_consistency(packet_keys) diff --git a/src/orcapod/types/schemas.py b/src/orcapod/types/schemas.py index 4f78ca5..19e8a3b 100644 --- a/src/orcapod/types/schemas.py +++ b/src/orcapod/types/schemas.py @@ -39,27 +39,37 @@ def arrow_to_python_type(arrow_type: pa.DataType) -> type: -# class PythonSchema(dict[str, type]): -# """ -# A schema for Python data types, mapping string keys to Python types. +class PythonSchema(dict[str, type]): + """ + A schema for Python data types, mapping string keys to Python types. -# This is used to define the expected structure of data packets in OrcaPod. + This is used to define the expected structure of data packets in OrcaPod. -# Attributes -# ---------- -# keys : str -# The keys of the schema. -# values : type -# The types corresponding to each key. + Attributes + ---------- + keys : str + The keys of the schema. + values : type + The types corresponding to each key. -# Examples -# -------- -# >>> schema = PythonSchema(name=str, age=int) -# >>> print(schema) -# {'name': , 'age': } -# """ + Examples + -------- + >>> schema = PythonSchema(name=str, age=int) + >>> print(schema) + {'name': , 'age': } + """ + @property + def with_source_info(self) -> dict[str, type]: + """ + Get the schema with source info fields included. + + Returns + ------- + dict[str, type|None] + A new schema including source info fields. + """ + return {**self, **{f'_source_info_{k}': str for k in self.keys()}} -PythonSchema = TypeSpec class SemanticSchema(dict[str, tuple[type, str|None]]): @@ -112,10 +122,42 @@ def get_semantic_type(self, key: str) -> str | None: The semantic type associated with the key, or None if not found. """ return self.get(key, (None, None))[1] + + @property + def storage_schema(self) -> PythonSchema: + """ + Get the storage schema, which is a PythonSchema representation of the semantic schema. + + Returns + ------- + PythonSchema + A new schema mapping keys to Python types. + """ + return PythonSchema({k: v[0] for k, v in self.items()}) + + + @property + def storage_schema_with_source_info(self) -> dict[str, type]: + """ + Get the storage schema with source info fields included. + + Returns + ------- + dict[str, type] + A new schema including source info fields. 
+
+        Examples
+        --------
+        >>> semantic_schema = SemanticSchema(name=(str, 'name'), age=(int, None))
+        >>> storage_schema = semantic_schema.storage_schema_with_source_info
+        >>> print(storage_schema)
+        {'name': <class 'str'>, 'age': <class 'int'>, '_source_info_name': <class 'str'>, '_source_info_age': <class 'str'>}
+        """
+        return self.storage_schema.with_source_info
-def from_python_schema_to_semantic_schema(
-    python_schema: PythonSchema,
+def from_typespec_to_semantic_schema(
+    typespec: TypeSpec,
     semantic_type_registry: SemanticTypeRegistry,
 ) -> SemanticSchema:
     """
@@ -123,8 +165,8 @@ def from_python_schema_to_semantic_schema(
     Parameters
     ----------
-    python_schema : PythonSchema
-        The schema to convert, mapping keys to Python types.
+    typespec : TypeSpec
+        The typespec to convert, mapping keys to Python types.
     semantic_type_registry : SemanticTypeRegistry
         The registry containing semantic type information.
@@ -135,13 +177,13 @@ def from_python_schema_to_semantic_schema(
     Examples
     --------
-    >>> python_schema = PythonSchema(name=str, age=int)
-    >>> semantic_schema = from_python_schema_to_semantic_schema(python_schema, registry)
+    >>> typespec: TypeSpec = dict(name=str, age=int)
+    >>> semantic_schema = from_typespec_to_semantic_schema(typespec, registry)
     >>> print(semantic_schema)
     {'name': (<class 'str'>, None), 'age': (<class 'int'>, None)}
     """
     semantic_schema = {}
-    for key, python_type in python_schema.items():
+    for key, python_type in typespec.items():
         if python_type in semantic_type_registry:
             type_info = semantic_type_registry.get_type_info(python_type)
             assert type_info is not None, f"Type {python_type} should be found in the registry as `in` returned True"
@@ -176,13 +218,13 @@ def from_semantic_schema_to_python_schema(
     >>> print(python_schema)
     {'name': <class 'str'>, 'age': <class 'int'>}
     """
-    python_schema = {}
+    python_schema_content = {}
     for key, (python_type, semantic_type) in semantic_schema.items():
         if semantic_type is not None:
             # If the semantic type is registered, use the corresponding Python type
             python_type = semantic_type_registry.get_python_type(semantic_type)
-        python_schema[key] = python_type
-    return python_schema
+        python_schema_content[key] = python_type
+    return PythonSchema(python_schema_content)
 def from_semantic_schema_to_arrow_schema(
     semantic_schema: SemanticSchema,
@@ -223,7 +265,6 @@ def from_semantic_schema_to_arrow_schema(
 def from_arrow_schema_to_semantic_schema(
     arrow_schema: pa.Schema,
-    semantic_type_registry: SemanticTypeRegistry | None = None,
 ) -> SemanticSchema:
     """
     Convert an Arrow schema to a semantic schema.
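To make the renamed schema helpers easier to follow, here is a small sketch of the typespec-to-schema chain. It is illustrative only: it assumes the default registry has a handler registered for pathlib.Path (as PathHandler elsewhere in this series suggests), and the exact tuple values and semantic labels depend on how that registry was populated.

    from pathlib import Path

    from orcapod.types import default_registry
    from orcapod.types.schemas import from_typespec_to_semantic_schema

    typespec = {"image": Path, "count": int}

    semantic = from_typespec_to_semantic_schema(typespec, default_registry)
    # maps each key to a (type, semantic label) pair; unregistered types like int get (int, None)

    storage = semantic.storage_schema_with_source_info
    # storage schema plus str-typed provenance columns, roughly:
    # {"image": str, "count": int, "_source_info_image": str, "_source_info_count": str}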
@@ -253,15 +294,39 @@ def from_arrow_schema_to_semantic_schema( continue semantic_type = field.metadata.get(b'semantic_type', None) semantic_type = semantic_type.decode() if semantic_type else None - if semantic_type: - if semantic_type_registry is None: - raise ValueError("Semantic type registry must be provided for semantic types") - python_type = semantic_type_registry.get_python_type(semantic_type) - if python_type is None: - raise ValueError(f"Semantic type '{semantic_type}' is not registered in the registry") - else: - python_type = arrow_to_python_type(field.type) - + python_type = arrow_to_python_type(field.type) semantic_schema[field.name] = (python_type, semantic_type) return SemanticSchema(semantic_schema) +def from_typespec_to_arrow_schema(typespec: TypeSpec, + semantic_type_registry: SemanticTypeRegistry, include_source_info: bool = True) -> pa.Schema: + semantic_schema = from_typespec_to_semantic_schema(typespec, semantic_type_registry) + return from_semantic_schema_to_arrow_schema(semantic_schema, include_source_info=include_source_info) + + +def from_arrow_schema_to_python_schema( + arrow_schema: pa.Schema, + semantic_type_registry: SemanticTypeRegistry, +) -> PythonSchema: + """ + Convert an Arrow schema to a Python schema. + + Parameters + ---------- + arrow_schema : pa.Schema + The schema to convert, containing fields with metadata. + + Returns + ------- + PythonSchema + A new schema mapping keys to Python types. + + Examples + -------- + >>> arrow_schema = pa.schema([pa.field('name', pa.string()), pa.field('age', pa.int64())]) + >>> python_schema = from_arrow_schema_to_python_schema(arrow_schema) + >>> print(python_schema) + {'name': , 'age': } + """ + semantic_schema = from_arrow_schema_to_semantic_schema(arrow_schema) + return from_semantic_schema_to_python_schema(semantic_schema, semantic_type_registry) \ No newline at end of file diff --git a/src/orcapod/types/semantic_type_handlers.py b/src/orcapod/types/semantic_type_handlers.py index a15f9d5..b3bc70c 100644 --- a/src/orcapod/types/semantic_type_handlers.py +++ b/src/orcapod/types/semantic_type_handlers.py @@ -12,8 +12,8 @@ class PathHandler: def python_type(self) -> type: return Path - def storage_type(self) -> pa.DataType: - return pa.string() + def storage_type(self) -> type: + return str def python_to_storage(self, value: Path) -> str: return str(value) @@ -28,8 +28,8 @@ class UUIDHandler: def python_type(self) -> type: return UUID - def storage_type(self) -> pa.DataType: - return pa.string() + def storage_type(self) -> type: + return str def python_to_storage(self, value: UUID) -> str: return str(value) @@ -44,8 +44,8 @@ class DecimalHandler: def python_type(self) -> type: return Decimal - def storage_type(self) -> pa.DataType: - return pa.string() + def storage_type(self) -> type: + return str def python_to_storage(self, value: Decimal) -> str: return str(value) @@ -57,34 +57,14 @@ def storage_to_python(self, value: str) -> Decimal | None: class SimpleMappingHandler: """Handler for basic types that map directly to Arrow.""" - def __init__(self, python_type: type, arrow_type: pa.DataType): + def __init__(self, python_type: type): self._python_type = python_type - self._arrow_type = arrow_type def python_type(self) -> type: return self._python_type - def storage_type(self) -> pa.DataType: - return self._arrow_type - - def python_to_storage(self, value: Any) -> Any: - return value # Direct mapping - - def storage_to_python(self, value: Any) -> Any: - return value # Direct mapping - - -class DirectArrowHandler: 
- """Handler for types that map directly to Arrow without conversion.""" - - def __init__(self, arrow_type: pa.DataType): - self._arrow_type = arrow_type - - def python_type(self) -> type: - return self._arrow_type - - def storage_type(self) -> pa.DataType: - return self._arrow_type + def storage_type(self) -> type: + return self._python_type def python_to_storage(self, value: Any) -> Any: return value # Direct mapping @@ -97,10 +77,10 @@ class DateTimeHandler: """Handler for datetime objects.""" def python_type(self) -> type: - return (datetime, date, time) # Handles multiple related types + return datetime - def storage_type(self) -> pa.DataType: - return pa.timestamp("us") # Store everything as timestamp + def storage_type(self) -> type: + return datetime def python_to_storage(self, value: datetime | date | time) -> Any: if isinstance(value, datetime): diff --git a/src/orcapod/types/semantic_type_registry.py b/src/orcapod/types/semantic_type_registry.py index d954891..d5a677f 100644 --- a/src/orcapod/types/semantic_type_registry.py +++ b/src/orcapod/types/semantic_type_registry.py @@ -1,11 +1,6 @@ -from collections.abc import Callable, Collection, Sequence, Mapping import logging -from optparse import Values -from typing import Any, get_origin, get_args -from types import UnionType import pyarrow as pa -from orcapod.types.packets import Packet, PacketLike -from .core import TypeHandler, TypeSpec +from .core import TypeHandler from dataclasses import dataclass # This mapping is expected to be stable @@ -77,16 +72,21 @@ def get_python_type(self, semantic_type: str) -> type | None: """Get Python type for a semantic type.""" return self._semantic_to_python_lut.get(semantic_type) - + def lookup_handler_info(self, python_type: type) -> tuple[TypeHandler, str] | None: + """Lookup handler info for a Python type.""" + for registered_type, (handler, semantic_type) in self._handlers.items(): + if issubclass(python_type, registered_type): + return (handler, semantic_type) + return None def get_semantic_type(self, python_type: type) -> str | None: """Get semantic type for a Python type.""" - handler_info = self._handlers.get(python_type) + handler_info = self.lookup_handler_info(python_type) return handler_info[1] if handler_info else None def get_handler(self, python_type: type) -> TypeHandler | None: """Get handler for a Python type.""" - handler_info = self._handlers.get(python_type) + handler_info = self.lookup_handler_info(python_type) return handler_info[0] if handler_info else None def get_handler_by_semantic_type(self, semantic_type: str) -> TypeHandler | None: @@ -110,7 +110,10 @@ def get_type_info(self, python_type: type) -> TypeInfo | None: def __contains__(self, python_type: type) -> bool: """Check if a Python type is registered.""" - return python_type in self._handlers + for registered_type in self._handlers: + if issubclass(python_type, registered_type): + return True + return False From d3b66de700871a6b0b2c6166ba0fe18e4613db2c Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 01:24:14 +0000 Subject: [PATCH 27/57] feat: support map and join on packets with source info --- src/orcapod/core/operators.py | 14 ++--------- src/orcapod/types/packets.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index c68f34f..c26dc2d 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -262,11 +262,7 @@ def generator() -> Iterator[tuple[Tag, Packet]]: for left_tag, left_packet in left_stream_buffered: for right_tag, right_packet in right_stream_buffered: if (joined_tag := join_tags(left_tag, right_tag)) is not None: - if not check_packet_compatibility(left_packet, right_packet): - raise ValueError( - f"Packets are not compatible: {left_packet} and {right_packet}" - ) - yield joined_tag, Packet({**left_packet, **right_packet}) + yield joined_tag, left_packet.join(right_packet) return SyncStreamFromGenerator(generator) @@ -399,13 +395,7 @@ def forward(self, *streams: SyncStream) -> SyncStream: def generator(): for tag, packet in stream: - if self.drop_unmapped: - packet = Packet({ - v: packet[k] for k, v in self.key_map.items() if k in packet - }) - else: - packet = Packet({self.key_map.get(k, k): v for k, v in packet.items()}) - yield tag, packet + yield tag, packet.map_keys(self.key_map, self.drop_unmapped) return SyncStreamFromGenerator(generator) diff --git a/src/orcapod/types/packets.py b/src/orcapod/types/packets.py index 4a3b192..a6621ee 100644 --- a/src/orcapod/types/packets.py +++ b/src/orcapod/types/packets.py @@ -46,7 +46,54 @@ def get_composite(self) -> PacketLike: for k, v in self.source_info.items(): composite[f"_source_info_{k}"] = v return composite + + def map_keys(self, mapping: Mapping[str, str], drop_unmapped: bool=False) -> 'Packet': + """ + Map the keys of the packet using the provided mapping. + + Args: + mapping: A dictionary mapping old keys to new keys. + + Returns: + A new Packet with keys mapped according to the provided mapping. + """ + if drop_unmapped: + new_content = { + v: self[k] for k, v in mapping.items() if k in self + } + new_typespec = { + v: self.typespec[k] for k, v in mapping.items() if k in self.typespec + } + new_source_info = { + v: self.source_info[k] for k, v in mapping.items() if k in self.source_info + } + else: + new_content = {mapping.get(k, k): v for k, v in self.items()} + new_typespec = {mapping.get(k, k): v for k, v in self.typespec.items()} + new_source_info = {mapping.get(k, k): v for k, v in self.source_info.items()} + return Packet(new_content, typespec=new_typespec, source_info=new_source_info) + + def join(self, other: 'Packet') -> 'Packet': + """ + Join another packet to this one, merging their keys and values. + + Args: + other: Another Packet to join with this one. + + Returns: + A new Packet with keys and values from both packets. + """ + # make sure there is no key collision + if not set(self.keys()).isdisjoint(other.keys()): + raise ValueError(f"Key collision detected: packets {self} and {other} have overlapping keys" + " and cannot be joined without losing information.") + + new_content = {**self, **other} + new_typespec = {**self.typespec, **other.typespec} + new_source_info = {**self.source_info, **other.source_info} + + return Packet(new_content, typespec=new_typespec, source_info=new_source_info) # a batch is a tuple of a tag and a list of packets From 0bafbaa08d6c534b5a1b53293d7ed5e5c0384e71 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 01:24:41 +0000 Subject: [PATCH 28/57] fix: keep all columns internally --- src/orcapod/core/streams.py | 2 ++ src/orcapod/pipeline/wrappers.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index 33f6b78..21060b1 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -96,6 +96,7 @@ def __iter__(self) -> Iterator[tuple[Tag, Packet]]: if not self.check_consistency: yield from self.generator_factory() + # TODO: add typespec handling def keys( self, *, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: @@ -103,3 +104,4 @@ def keys( return super().keys(trigger_run=trigger_run) # If the keys are already set, return them return self.tag_keys.copy(), self.packet_keys.copy() + \ No newline at end of file diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index c12f40a..e999714 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -460,11 +460,11 @@ def generator_completion_hook(self, n_computed: int) -> None: def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: logger.info(f"Returning cached outputs for {self}") - if self.df is not None: + lazy_df = self.get_all_entries_with_tags(keep_hidden_fields=True) + if lazy_df is not None: if self.tag_keys is None: raise ValueError("Tag keys are not set, cannot return PolarsStream") - - return PolarsStream(self.df, self.tag_keys, packet_keys=self.output_keys) + return PolarsStream(lazy_df.collect(), self.tag_keys, packet_keys=self.output_keys) else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.output_keys) logger.info(f"Computing and caching outputs for {self}") From 6321467e88cd4f7b4e5cc9ebaf55fe24dfb21498 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 03:54:42 +0000 Subject: [PATCH 29/57] wip: update legacy file related tests and rename to stores --- src/orcapod/__init__.py | 8 +- src/orcapod/core/pod_legacy.py | 2 +- src/orcapod/hashing/__init__.py | 8 +- src/orcapod/hashing/defaults.py | 6 +- src/orcapod/hashing/file_hashers.py | 104 ++-- src/orcapod/hashing/legacy_core.py | 4 +- src/orcapod/hashing/types.py | 60 +- src/orcapod/pipeline/__init__.py | 5 + src/orcapod/pipeline/pipeline.py | 15 +- src/orcapod/pipeline/wrappers.py | 31 +- src/orcapod/{store => stores}/__init__.py | 0 .../{store => stores}/arrow_data_stores.py | 49 +- .../stores/delta_table_arrow_data_store.py | 559 ++++++++++++++++++ .../{store => stores}/dict_data_stores.py | 26 +- src/orcapod/{store => stores}/file_ops.py | 0 .../optimized_memory_store.py | 0 .../{store => stores}/safe_dir_data_store.py | 0 .../{store => stores}/transfer_data_store.py | 16 +- src/orcapod/{store => stores}/types.py | 27 +- .../test_basic_composite_hasher.py | 20 +- tests/test_hashing/test_cached_file_hasher.py | 6 +- tests/test_hashing/test_hasher_factory.py | 86 +-- tests/test_hashing/test_hasher_parity.py | 12 +- .../test_legacy_composite_hasher.py | 156 +++++ tests/test_hashing/test_packet_hasher.py | 38 +- tests/test_hashing/test_path_set_hasher.py | 20 +- tests/test_store/test_dir_data_store.py | 28 +- tests/test_store/test_integration.py | 6 +- tests/test_store/test_noop_data_store.py | 4 +- tests/test_store/test_transfer_data_store.py | 8 +- 30 files changed, 1041 insertions(+), 263 deletions(-) create mode 100644 src/orcapod/pipeline/__init__.py rename src/orcapod/{store => stores}/__init__.py (100%) rename src/orcapod/{store => stores}/arrow_data_stores.py (98%) create mode 100644 src/orcapod/stores/delta_table_arrow_data_store.py rename src/orcapod/{store => stores}/dict_data_stores.py (95%) rename src/orcapod/{store => stores}/file_ops.py (100%) rename src/orcapod/{store => stores}/optimized_memory_store.py (100%) rename src/orcapod/{store => stores}/safe_dir_data_store.py (100%) rename src/orcapod/{store => stores}/transfer_data_store.py (90%) rename src/orcapod/{store => stores}/types.py (79%) create mode 100644 tests/test_hashing/test_legacy_composite_hasher.py diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index db457e9..ad00035 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -1,11 +1,12 @@ from .core import operators, sources, streams from .core.streams import SyncStreamFromLists, SyncStreamFromGenerator -from . import hashing, store +from . 
import hashing, stores from .core.operators import Join, MapPackets, MapTags, packet, tag from .core.pod import FunctionPod, function_pod from .core.sources import GlobSource -from .store import DirDataStore, SafeDirDataStore +from .stores import DirDataStore, SafeDirDataStore from .core.tracker import GraphTracker +from .pipeline import Pipeline DEFAULT_TRACKER = GraphTracker() DEFAULT_TRACKER.activate() @@ -13,7 +14,7 @@ __all__ = [ "hashing", - "store", + "stores", "pod", "operators", "streams", @@ -31,4 +32,5 @@ "DEFAULT_TRACKER", "SyncStreamFromLists", "SyncStreamFromGenerator", + "Pipeline", ] diff --git a/src/orcapod/core/pod_legacy.py b/src/orcapod/core/pod_legacy.py index 32c8efb..18099c6 100644 --- a/src/orcapod/core/pod_legacy.py +++ b/src/orcapod/core/pod_legacy.py @@ -16,7 +16,7 @@ from orcapod.core.base import Kernel from orcapod.core.operators import Join from orcapod.core.streams import SyncStream, SyncStreamFromGenerator -from orcapod.store import DataStore, NoOpDataStore +from orcapod.stores import DataStore, NoOpDataStore logger = logging.getLogger(__name__) diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index 2bdff2b..7aaf11b 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -17,22 +17,22 @@ ) from .types import ( FileContentHasher, - PacketHasher, + LegacyPacketHasher, ArrowHasher, ObjectHasher, StringCacher, FunctionInfoExtractor, - CompositeFileHasher, + LegacyCompositeFileHasher, ) from .content_identifiable import ContentIdentifiableBase __all__ = [ "FileContentHasher", - "PacketHasher", + "LegacyPacketHasher", "ArrowHasher", "StringCacher", "ObjectHasher", - "CompositeFileHasher", + "LegacyCompositeFileHasher", "FunctionInfoExtractor", "hash_file", "hash_pathset", diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 61539b5..8ba7c0b 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -1,7 +1,7 @@ # A collection of utility function that provides a "default" implementation of hashers. # This is often used as the fallback hasher in the library code. 
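A quick sketch of the renamed default-hasher entry point, for orientation only; the file path below is made up and would have to exist on disk for a real hash:

    from orcapod.hashing.defaults import get_default_composite_file_hasher

    hasher = get_default_composite_file_hasher(with_cache=True)  # LegacyCompositeFileHasher
    digest = hasher.hash_file("data/example.txt")                # hex digest string; repeat calls hit the cache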
from orcapod.hashing.types import ( - CompositeFileHasher, + LegacyCompositeFileHasher, ArrowHasher, FileContentHasher, StringCacher, @@ -36,7 +36,7 @@ def get_default_arrow_hasher( return arrow_hasher -def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: +def get_default_composite_file_hasher(with_cache=True) -> LegacyCompositeFileHasher: if with_cache: # use unlimited caching string_cacher = InMemoryCacher(max_size=None) @@ -44,7 +44,7 @@ def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: return LegacyPathLikeHasherFactory.create_basic_legacy_composite() -def get_default_composite_file_hasher_with_cacher(cacher=None) -> CompositeFileHasher: +def get_default_composite_file_hasher_with_cacher(cacher=None) -> LegacyCompositeFileHasher: if cacher is None: cacher = InMemoryCacher(max_size=None) return LegacyPathLikeHasherFactory.create_cached_legacy_composite(cacher) diff --git a/src/orcapod/hashing/file_hashers.py b/src/orcapod/hashing/file_hashers.py index cd12e80..64f48f8 100644 --- a/src/orcapod/hashing/file_hashers.py +++ b/src/orcapod/hashing/file_hashers.py @@ -2,11 +2,12 @@ from orcapod.hashing.hash_utils import hash_file from orcapod.hashing.types import ( FileContentHasher, - PathSetHasher, StringCacher, - CompositeFileHasher, + LegacyFileHasher, + LegacyPathSetHasher, + LegacyCompositeFileHasher, ) -from orcapod.types import Packet, PathLike, PathSet +from orcapod.types import PacketLike, PathLike, PathSet class BasicFileHasher: @@ -51,7 +52,7 @@ def hash_file(self, file_path: PathLike) -> bytes: # ----------------Legacy implementations for backward compatibility----------------- -class LegacyFileHasher: +class LegacyDefaultFileHasher: def __init__( self, algorithm: str = "sha256", @@ -60,45 +61,65 @@ def __init__( self.algorithm = algorithm self.buffer_size = buffer_size - def hash_file(self, file_path: PathLike) -> bytes: - return bytes.fromhex( - legacy_core.hash_file( - file_path, algorithm=self.algorithm, buffer_size=self.buffer_size - ), + def hash_file(self, file_path: PathLike) -> str: + return legacy_core.hash_file( + file_path, algorithm=self.algorithm, buffer_size=self.buffer_size ) -class LegacyPathsetHasher: + +class LegacyCachedFileHasher: + """File hasher with caching.""" + + def __init__( + self, + file_hasher: LegacyFileHasher, + string_cacher: StringCacher, + ): + self.file_hasher = file_hasher + self.string_cacher = string_cacher + + def hash_file(self, file_path: PathLike) -> str: + cache_key = f"file:{file_path}" + cached_value = self.string_cacher.get_cached(cache_key) + if cached_value is not None: + return cached_value + + value = self.file_hasher.hash_file(file_path) + self.string_cacher.set_cached(cache_key, value) + return value + + + +class LegacyDefaultPathsetHasher: """Default pathset hasher that composes file hashing.""" def __init__( self, - file_hasher: FileContentHasher, + file_hasher: LegacyFileHasher, char_count: int | None = 32, ): self.file_hasher = file_hasher self.char_count = char_count def _hash_file_to_hex(self, file_path: PathLike) -> str: - return self.file_hasher.hash_file(file_path).hex() + return self.file_hasher.hash_file(file_path) - def hash_pathset(self, pathset: PathSet) -> bytes: + def hash_pathset(self, pathset: PathSet) -> str: """Hash a pathset using the injected file hasher.""" - return bytes.fromhex( - legacy_core.hash_pathset( + return legacy_core.hash_pathset( pathset, char_count=self.char_count, - file_hasher=self._hash_file_to_hex, # Inject the method + 
file_hasher=self.file_hasher.hash_file, # Inject the method ) - ) -class LegacyPacketHasher: +class LegacyDefaultPacketHasher: """Default packet hasher that composes pathset hashing.""" def __init__( self, - pathset_hasher: PathSetHasher, + pathset_hasher: LegacyPathSetHasher, char_count: int | None = 32, prefix: str = "", ): @@ -107,9 +128,9 @@ def __init__( self.prefix = prefix def _hash_pathset_to_hex(self, pathset: PathSet): - return self.pathset_hasher.hash_pathset(pathset).hex() + return self.pathset_hasher.hash_pathset(pathset) - def hash_packet(self, packet: Packet) -> str: + def hash_packet(self, packet: PacketLike) -> str: """Hash a packet using the injected pathset hasher.""" hash_str = legacy_core.hash_packet( packet, @@ -121,28 +142,28 @@ def hash_packet(self, packet: Packet) -> str: # Convenience composite implementation -class LegacyCompositeFileHasher: +class LegacyDefaultCompositeFileHasher: """Composite hasher that implements all interfaces.""" def __init__( self, - file_hasher: FileContentHasher, + file_hasher: LegacyFileHasher, char_count: int | None = 32, packet_prefix: str = "", ): self.file_hasher = file_hasher - self.pathset_hasher = LegacyPathsetHasher(self.file_hasher, char_count) - self.packet_hasher = LegacyPacketHasher( + self.pathset_hasher = LegacyDefaultPathsetHasher(self.file_hasher, char_count) + self.packet_hasher = LegacyDefaultPacketHasher( self.pathset_hasher, char_count, packet_prefix ) - def hash_file(self, file_path: PathLike) -> bytes: + def hash_file(self, file_path: PathLike) -> str: return self.file_hasher.hash_file(file_path) - def hash_pathset(self, pathset: PathSet) -> bytes: + def hash_pathset(self, pathset: PathSet) -> str: return self.pathset_hasher.hash_pathset(pathset) - def hash_packet(self, packet: Packet) -> str: + def hash_packet(self, packet: PacketLike) -> str: return self.packet_hasher.hash_packet(packet) @@ -155,11 +176,11 @@ def create_basic_legacy_composite( algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, - ) -> CompositeFileHasher: + ) -> LegacyCompositeFileHasher: """Create a basic composite hasher.""" - file_hasher = LegacyFileHasher(algorithm, buffer_size) + file_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) # use algorithm as the prefix for the packet hasher - return LegacyCompositeFileHasher( + return LegacyDefaultCompositeFileHasher( file_hasher, char_count, packet_prefix=algorithm ) @@ -169,13 +190,26 @@ def create_cached_legacy_composite( algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, - ) -> CompositeFileHasher: + ) -> LegacyCompositeFileHasher: """Create a composite hasher with file caching.""" - basic_file_hasher = LegacyFileHasher(algorithm, buffer_size) - cached_file_hasher = CachedFileHasher(basic_file_hasher, string_cacher) - return LegacyCompositeFileHasher( + basic_file_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) + cached_file_hasher = LegacyCachedFileHasher(basic_file_hasher, string_cacher) + return LegacyDefaultCompositeFileHasher( cached_file_hasher, char_count, packet_prefix=algorithm ) + + @staticmethod + def create_legacy_file_hasher( + string_cacher: StringCacher | None = None, + algorithm: str = "sha256", + buffer_size: int = 65536, + ) -> LegacyFileHasher: + """Create just a file hasher, optionally with caching.""" + default_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) + if string_cacher is None: + return default_hasher + else: + return LegacyCachedFileHasher(default_hasher, string_cacher) 
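For orientation, a minimal usage sketch of the factory above. The paths and packet are invented, a real call requires the files to exist, and the pathset line assumes a collection of paths is a valid PathSet:

    from orcapod.hashing.file_hashers import LegacyPathLikeHasherFactory

    hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite()

    file_digest = hasher.hash_file("data/example.txt")                  # hex string (sha256 by default)
    pathset_digest = hasher.hash_pathset(["data/a.txt", "data/b.txt"])  # assumes list-of-paths PathSet
    packet_digest = hasher.hash_packet({"input_file": "data/example.txt"})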
@staticmethod def create_file_hasher( diff --git a/src/orcapod/hashing/legacy_core.py b/src/orcapod/hashing/legacy_core.py index cfe9c56..a5b4319 100644 --- a/src/orcapod/hashing/legacy_core.py +++ b/src/orcapod/hashing/legacy_core.py @@ -33,7 +33,7 @@ import xxhash -from orcapod.types import Packet, PathSet +from orcapod.types import Packet, PacketLike, PathSet from orcapod.utils.name import find_noncolliding_name # Configure logging with __name__ for proper hierarchy @@ -681,7 +681,7 @@ def hash_packet_with_psh( def hash_packet( - packet: Packet, + packet: PacketLike, algorithm: str = "sha256", buffer_size: int = 65536, char_count: Optional[int] = 32, diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index 10ed267..c7d79da 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -5,7 +5,7 @@ from typing import Any, Protocol, runtime_checkable import uuid -from orcapod.types import Packet, PathLike, PathSet, TypeSpec +from orcapod.types import PacketLike, PathLike, PathSet, TypeSpec import pyarrow as pa @@ -85,21 +85,6 @@ class FileContentHasher(Protocol): def hash_file(self, file_path: PathLike) -> bytes: ... -# Higher-level operations that compose file hashing -@runtime_checkable -class PathSetHasher(Protocol): - """Protocol for hashing pathsets (files, directories, collections).""" - - def hash_pathset(self, pathset: PathSet) -> bytes: ... - - -@runtime_checkable -class PacketHasher(Protocol): - """Protocol for hashing packets.""" - - def hash_packet(self, packet: Packet) -> str: ... - - @runtime_checkable class ArrowHasher(Protocol): """Protocol for hashing arrow packets.""" @@ -116,14 +101,6 @@ def set_cached(self, cache_key: str, value: str) -> None: ... def clear_cache(self) -> None: ... -# Combined interface for convenience (optional) -@runtime_checkable -class CompositeFileHasher(FileContentHasher, PathSetHasher, PacketHasher, Protocol): - """Combined interface for all file-related hashing operations.""" - - pass - - # Function hasher protocol @runtime_checkable class FunctionInfoExtractor(Protocol): @@ -153,3 +130,38 @@ def hash_column( def set_cacher(self, cacher: StringCacher) -> None: """Add a string cacher for caching hash values.""" pass + + +#---------------Legacy implementations and protocols to be deprecated--------------------- + + +@runtime_checkable +class LegacyFileHasher(Protocol): + """Protocol for file-related hashing.""" + + def hash_file(self, file_path: PathLike) -> str: ... + + +# Higher-level operations that compose file hashing +@runtime_checkable +class LegacyPathSetHasher(Protocol): + """Protocol for hashing pathsets (files, directories, collections).""" + + def hash_pathset(self, pathset: PathSet) -> str: ... + + +@runtime_checkable +class LegacyPacketHasher(Protocol): + """Protocol for hashing packets.""" + + def hash_packet(self, packet: PacketLike) -> str: ... 
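Because these legacy protocols are declared @runtime_checkable, structural isinstance checks work without inheritance. A toy illustration (not a real hasher, and not part of the patch):

    from orcapod.hashing import LegacyPacketHasher

    class ConstantPacketHasher:
        """Toy stand-in that satisfies the protocol structurally."""

        def hash_packet(self, packet) -> str:
            return "deadbeef"  # not a real hash

    assert isinstance(ConstantPacketHasher(), LegacyPacketHasher)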
+ + +# Combined interface for convenience (optional) +@runtime_checkable +class LegacyCompositeFileHasher(LegacyFileHasher, LegacyPathSetHasher, LegacyPacketHasher, Protocol): + """Combined interface for all file-related hashing operations.""" + + pass + + diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py new file mode 100644 index 0000000..2bba49b --- /dev/null +++ b/src/orcapod/pipeline/__init__.py @@ -0,0 +1,5 @@ +from .pipeline import Pipeline + +__all__ = [ + "Pipeline", +] \ No newline at end of file diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 74eb998..864f649 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -14,7 +14,7 @@ from orcapod.hashing import hash_to_hex from orcapod.core.tracker import GraphTracker -from orcapod.store import ArrowDataStore +from orcapod.stores import ArrowDataStore logger = logging.getLogger(__name__) @@ -33,15 +33,17 @@ class Pipeline(GraphTracker): def __init__( self, - name: str, - results_store: ArrowDataStore, + name: str | tuple[str, ...], pipeline_store: ArrowDataStore, + results_store: ArrowDataStore, auto_compile: bool = True, ) -> None: super().__init__() - self.name = name or f"pipeline_{id(self)}" - self.results_store = results_store + if not isinstance(name, tuple): + name = (name,) + self.name = name self.pipeline_store = pipeline_store + self.results_store = results_store self.labels_to_nodes = {} self.auto_compile = auto_compile self._dirty = False @@ -92,8 +94,9 @@ def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node input_nodes, output_store=self.results_store, tag_store=self.pipeline_store, + store_path_prefix=self.name, ) - return KernelNode(kernel, input_nodes, output_store=self.pipeline_store) + return KernelNode(kernel, input_nodes, output_store=self.pipeline_store, store_path_prefix=self.name) def compile(self): import networkx as nx diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/wrappers.py index e999714..720609e 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/wrappers.py @@ -1,6 +1,6 @@ from orcapod.core.pod import Pod, FunctionPod from orcapod.core import SyncStream, Source, Kernel -from orcapod.store import ArrowDataStore +from orcapod.stores import ArrowDataStore from orcapod.types import Tag, Packet, PacketLike, TypeSpec, default_registry from orcapod.types.typespec_utils import get_typespec_from_dict, union_typespecs, extract_function_typespecs from orcapod.types.semantic_type_registry import SemanticTypeRegistry @@ -167,6 +167,7 @@ def __init__( kernel: Kernel, input_streams: Collection[SyncStream], output_store: ArrowDataStore, + store_path_prefix: tuple[str, ...] 
| None = None, kernel_hasher: ObjectHasher | None = None, arrow_packet_hasher: ArrowHasher | None = None, packet_type_registry: SemanticTypeRegistry | None = None, @@ -175,6 +176,7 @@ def __init__( super().__init__(kernel, input_streams, **kwargs) self.output_store = output_store + self.store_path_prefix = store_path_prefix or () # These are configurable but are not expected to be modified except for special circumstances if kernel_hasher is None: @@ -214,7 +216,7 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): self.update_cached_values() def update_cached_values(self): - self.source_info = self.label, self.kernel_hasher.hash_to_hex(self.kernel) + self.source_info = self.store_path_prefix + (self.label, self.kernel_hasher.hash_to_hex(self.kernel)) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.packet_typespec is None: @@ -270,9 +272,9 @@ def post_call(self, tag: Tag, packet: Packet) -> None: output_table = self.output_converter.from_python_packet_to_arrow_table(merged_info) # TODO: revisit this logic output_id = self.arrow_hasher.hash_table(output_table) - if not self.output_store.get_record(*self.source_info, output_id): + if not self.output_store.get_record(self.source_info, output_id): self.output_store.add_record( - *self.source_info, + self.source_info, output_id, output_table, ) @@ -287,7 +289,7 @@ def output_iterator_completion_hook(self) -> None: @property def lazy_df(self) -> pl.LazyFrame | None: - return self.output_store.get_all_records_as_polars(*self.source_info) + return self.output_store.get_all_records_as_polars(self.source_info) @property def df(self) -> pl.DataFrame | None: @@ -345,6 +347,7 @@ def __init__( output_store: ArrowDataStore, tag_store: ArrowDataStore | None = None, label: str | None = None, + store_path_prefix: tuple[str, ...] 
| None = None, skip_memoization_lookup: bool = False, skip_memoization: bool = False, skip_tag_record: bool = False, @@ -361,6 +364,7 @@ def __init__( error_handling=error_handling, **kwargs, ) + self.store_path_prefix = store_path_prefix or () self.output_store = output_store self.tag_store = tag_store @@ -502,9 +506,9 @@ def _add_pipeline_record_with_packet_key(self, tag: Tag, packet_key: str, packet # TODO: add error handling # check if record already exists: - retrieved_table = self.tag_store.get_record(*self.source_info, entry_hash) + retrieved_table = self.tag_store.get_record(self.source_info, entry_hash) if retrieved_table is None: - self.tag_store.add_record(*self.source_info, entry_hash, table) + self.tag_store.add_record(self.source_info, entry_hash, table) return tag @@ -523,8 +527,7 @@ def _retrieve_memoized_with_packet_key(self, packet_key: str) -> Packet | None: """ logger.debug(f"Retrieving memoized packet with key {packet_key}") arrow_table = self.output_store.get_record( - self.function_pod.function_name, - self.function_pod_hash, + self.source_info, packet_key, ) if arrow_table is None: @@ -560,7 +563,7 @@ def _memoize_with_packet_key( # consider simpler alternative packets = self.output_converter.from_arrow_table_to_python_packets( self.output_store.add_record( - *self.source_info, + self.source_info, packet_key, self.output_converter.from_python_packet_to_arrow_table(output_packet), ) @@ -622,12 +625,12 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: return tag, output_packet def get_all_outputs(self) -> pl.LazyFrame | None: - return self.output_store.get_all_records_as_polars(*self.source_info) + return self.output_store.get_all_records_as_polars(self.source_info) def get_all_tags(self, with_packet_id: bool = False) -> pl.LazyFrame | None: if self.tag_store is None: raise ValueError("Tag store is not set, no tag record can be retrieved") - data = self.tag_store.get_all_records_as_polars(*self.source_info) + data = self.tag_store.get_all_records_as_polars(self.source_info) if not with_packet_id: return data.drop("__packet_key") if data is not None else None return data @@ -640,11 +643,11 @@ def get_all_entries_with_tags(self, keep_hidden_fields: bool = False) -> pl.Lazy if self.tag_store is None: raise ValueError("Tag store is not set, no tag record can be retrieved") - tag_records = self.tag_store.get_all_records_as_polars(*self.source_info) + tag_records = self.tag_store.get_all_records_as_polars(self.source_info) if tag_records is None: return None result_packets = self.output_store.get_records_by_ids_as_polars( - *self.source_info, + self.source_info, tag_records.collect()["__packet_key"], preserve_input_order=True, ) diff --git a/src/orcapod/store/__init__.py b/src/orcapod/stores/__init__.py similarity index 100% rename from src/orcapod/store/__init__.py rename to src/orcapod/stores/__init__.py diff --git a/src/orcapod/store/arrow_data_stores.py b/src/orcapod/stores/arrow_data_stores.py similarity index 98% rename from src/orcapod/store/arrow_data_stores.py rename to src/orcapod/stores/arrow_data_stores.py index e2c1376..2608cbc 100644 --- a/src/orcapod/store/arrow_data_stores.py +++ b/src/orcapod/stores/arrow_data_stores.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta import logging -from orcapod.store.types import DuplicateError +from orcapod.stores.types import DuplicateError # Module-level logger logger = logging.getLogger(__name__) @@ -24,30 +24,30 @@ def __init__(self): 
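The switch to tuple-based source paths is easier to see with a toy call sequence against the in-memory store from this module; the path components and entry id below are invented:

    import pyarrow as pa

    from orcapod.stores.arrow_data_stores import InMemoryArrowDataStore

    store = InMemoryArrowDataStore()
    source_path = ("my_pipeline", "my_pod", "0123abcd")  # e.g. store prefix + label + kernel hash

    store.add_record(source_path, "entry-001", pa.table({"result": [42]}))
    record = store.get_record(source_path, "entry-001")  # returns the stored pa.Table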
logger.info("Initialized MockArrowDataStore") def add_record( - self, source_name: str, source_id: str, entry_id: str, arrow_data: pa.Table + self, source_pathh: tuple[str, ...], source_id: str, entry_id: str, arrow_data: pa.Table ) -> pa.Table: """Add a record to the mock store.""" return arrow_data def get_record( - self, source_name: str, source_id: str, entry_id: str + self, source_path: tuple[str, ...], source_id: str, entry_id: str ) -> pa.Table | None: """Get a specific record.""" return None - def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + def get_all_records(self, source_path: tuple[str, ...], source_id: str) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" return None def get_all_records_as_polars( - self, source_name: str, source_id: str + self, source_path: tuple[str, ...], source_id: str ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a single Polars LazyFrame.""" return None def get_records_by_ids( self, - source_name: str, + source_path: tuple[str,...], source_id: str, entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, @@ -77,7 +77,7 @@ def get_records_by_ids( def get_records_by_ids_as_polars( self, - source_name: str, + source_path: tuple[str,...], source_id: str, entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, @@ -115,14 +115,13 @@ def __init__(self, duplicate_entry_behavior: str = "error"): f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'" ) - def _get_source_key(self, source_name: str, source_id: str) -> str: + def _get_source_key(self, source_path: tuple[str, ...]) -> str: """Generate key for source storage.""" - return f"{source_name}:{source_id}" + return "/".join(source_path) def add_record( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_id: str, arrow_data: pa.Table, ignore_duplicate: bool = False, @@ -142,7 +141,7 @@ def add_record( Raises: ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' """ - source_key = self._get_source_key(source_name, source_id) + source_key = self._get_source_key(source_path) # Initialize source if it doesn't exist if source_key not in self._in_memory_store: @@ -154,7 +153,7 @@ def add_record( if entry_id in local_data: if not ignore_duplicate and self.duplicate_entry_behavior == "error": raise ValueError( - f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " + f"Entry '{entry_id}' already exists in {source_key}. " f"Use duplicate_entry_behavior='overwrite' to allow updates." 
) @@ -166,18 +165,18 @@ def add_record( return arrow_data def get_record( - self, source_name: str, source_id: str, entry_id: str + self, source_path: tuple[str, ...], entry_id: str ) -> pa.Table | None: """Get a specific record.""" - source_key = self._get_source_key(source_name, source_id) + source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) return local_data.get(entry_id) def get_all_records( - self, source_name: str, source_id: str, add_entry_id_column: bool | str = False + self, source_path: tuple[str, ...], add_entry_id_column: bool | str = False ) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" - source_key = self._get_source_key(source_name, source_id) + source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) if not local_data: @@ -199,18 +198,17 @@ def get_all_records( return None def get_all_records_as_polars( - self, source_name: str, source_id: str + self, source_path: tuple[str, ...] ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a single Polars LazyFrame.""" - all_records = self.get_all_records(source_name, source_id) + all_records = self.get_all_records(source_path) if all_records is None: return None return pl.LazyFrame(all_records) def get_records_by_ids( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, preserve_input_order: bool = False, @@ -253,7 +251,7 @@ def get_records_by_ids( f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" ) - source_key = self._get_source_key(source_name, source_id) + source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) if not local_data: @@ -340,8 +338,7 @@ def get_records_by_ids( def get_records_by_ids_as_polars( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, preserve_input_order: bool = False, @@ -368,7 +365,7 @@ def get_records_by_ids_as_polars( """ # Get Arrow result and convert to Polars arrow_result = self.get_records_by_ids( - source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order + source_path, entry_ids, add_entry_id_column, preserve_input_order ) if arrow_result is None: @@ -464,7 +461,7 @@ def load_from_parquet(self, base_path: str | Path) -> None: continue source_id = source_id_dir.name - source_key = self._get_source_key(source_name, source_id) + source_key = self._get_source_key((source_name, source_id)) # Look for Parquet files in this directory parquet_files = list(source_id_dir.glob("*.parquet")) diff --git a/src/orcapod/stores/delta_table_arrow_data_store.py b/src/orcapod/stores/delta_table_arrow_data_store.py new file mode 100644 index 0000000..d4fcaf3 --- /dev/null +++ b/src/orcapod/stores/delta_table_arrow_data_store.py @@ -0,0 +1,559 @@ +import pyarrow as pa +import polars as pl +from pathlib import Path +from typing import Any, Union +import logging +from deltalake import DeltaTable, write_deltalake +from deltalake.exceptions import TableNotFoundError + +# Module-level logger +logger = logging.getLogger(__name__) + + +class DeltaTableArrowDataStore: + """ + Delta Table-based Arrow data store with flexible hierarchical path support. 
+ + Uses tuple-based source paths for robust parameter handling: + - ("source_name", "source_id") -> source_name/source_id/ + - ("org", "project", "dataset") -> org/project/dataset/ + - ("year", "month", "day", "experiment") -> year/month/day/experiment/ + """ + + def __init__( + self, + base_path: str | Path, + duplicate_entry_behavior: str = "error", + create_base_path: bool = True, + max_hierarchy_depth: int = 10 + ): + """ + Initialize the DeltaTableArrowDataStore. + + Args: + base_path: Base directory path where Delta tables will be stored + duplicate_entry_behavior: How to handle duplicate entry_ids: + - 'error': Raise ValueError when entry_id already exists + - 'overwrite': Replace existing entry with new data + create_base_path: Whether to create the base path if it doesn't exist + max_hierarchy_depth: Maximum allowed depth for source paths (safety limit) + """ + # Validate duplicate behavior + if duplicate_entry_behavior not in ["error", "overwrite"]: + raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") + + self.duplicate_entry_behavior = duplicate_entry_behavior + self.base_path = Path(base_path) + self.max_hierarchy_depth = max_hierarchy_depth + + if create_base_path: + self.base_path.mkdir(parents=True, exist_ok=True) + elif not self.base_path.exists(): + raise ValueError(f"Base path {self.base_path} does not exist and create_base_path=False") + + # Cache for Delta tables to avoid repeated initialization + self._delta_table_cache: dict[str, DeltaTable] = {} + + logger.info( + f"Initialized DeltaTableArrowDataStore at {self.base_path} " + f"with duplicate_entry_behavior='{duplicate_entry_behavior}'" + ) + + def _validate_source_path(self, source_path: tuple[str, ...]) -> None: + """ + Validate source path components. 
+ + Args: + source_path: Tuple of path components + + Raises: + ValueError: If path is invalid + """ + if not source_path: + raise ValueError("Source path cannot be empty") + + if len(source_path) > self.max_hierarchy_depth: + raise ValueError(f"Source path depth {len(source_path)} exceeds maximum {self.max_hierarchy_depth}") + + # Validate path components + for i, component in enumerate(source_path): + if not component or not isinstance(component, str): + raise ValueError(f"Source path component {i} is invalid: {repr(component)}") + + # Check for filesystem-unsafe characters + unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\0'] + if any(char in component for char in unsafe_chars): + raise ValueError(f"Source path component contains invalid characters: {repr(component)}") + + def _get_source_key(self, source_path: tuple[str, ...]) -> str: + """Generate cache key for source storage.""" + return "/".join(source_path) + + def _get_table_path(self, source_path: tuple[str, ...]) -> Path: + """Get the filesystem path for a given source path.""" + path = self.base_path + for component in source_path: + path = path / component + return path + + def _ensure_entry_id_column(self, arrow_data: pa.Table, entry_id: str) -> pa.Table: + """Ensure the table has an __entry_id column.""" + if "__entry_id" not in arrow_data.column_names: + # Add entry_id column at the beginning + key_array = pa.array([entry_id] * len(arrow_data), type=pa.large_string()) + arrow_data = arrow_data.add_column(0, "__entry_id", key_array) + return arrow_data + + def _remove_entry_id_column(self, arrow_data: pa.Table) -> pa.Table: + """Remove the __entry_id column if it exists.""" + if "__entry_id" in arrow_data.column_names: + column_names = arrow_data.column_names + indices_to_keep = [ + i for i, name in enumerate(column_names) if name != "__entry_id" + ] + arrow_data = arrow_data.select(indices_to_keep) + return arrow_data + + def _handle_entry_id_column( + self, + arrow_data: pa.Table, + add_entry_id_column: bool | str = False + ) -> pa.Table: + """ + Handle entry_id column based on add_entry_id_column parameter. + + Args: + arrow_data: Arrow table with __entry_id column + add_entry_id_column: Control entry ID column inclusion: + - False: Remove __entry_id column + - True: Keep __entry_id column as is + - str: Rename __entry_id column to custom name + """ + if add_entry_id_column is False: + # Remove the __entry_id column + return self._remove_entry_id_column(arrow_data) + elif isinstance(add_entry_id_column, str): + # Rename __entry_id to custom name + if "__entry_id" in arrow_data.column_names: + schema = arrow_data.schema + new_names = [ + add_entry_id_column if name == "__entry_id" else name + for name in schema.names + ] + return arrow_data.rename_columns(new_names) + # If add_entry_id_column is True, keep __entry_id as is + return arrow_data + + def add_record( + self, + source_path: tuple[str, ...], + entry_id: str, + arrow_data: pa.Table, + ignore_duplicate: bool = False, + ) -> pa.Table: + """ + Add a record to the Delta table. 
+ + Args: + source_path: Tuple of path components (e.g., ("org", "project", "dataset")) + entry_id: Unique identifier for this record + arrow_data: The Arrow table data to store + ignore_duplicate: If True, ignore duplicate entry error + + Returns: + The Arrow table data that was stored + + Raises: + ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + source_key = self._get_source_key(source_path) + + # Ensure directory exists + table_path.mkdir(parents=True, exist_ok=True) + + # Add entry_id column to the data + data_with_entry_id = self._ensure_entry_id_column(arrow_data, entry_id) + + # Check for existing entry if needed + if not ignore_duplicate and self.duplicate_entry_behavior == "error": + existing_record = self.get_record(source_path, entry_id) + if existing_record is not None: + raise ValueError( + f"Entry '{entry_id}' already exists in {'/'.join(source_path)}. " + f"Use duplicate_entry_behavior='overwrite' to allow updates." + ) + + try: + # Try to load existing table + delta_table = DeltaTable(str(table_path)) + + if self.duplicate_entry_behavior == "overwrite": + # Delete existing record if it exists, then append new one + try: + # First, delete existing record with this entry_id + delta_table.delete(f"__entry_id = '{entry_id}'") + logger.debug(f"Deleted existing record {entry_id} from {source_key}") + except Exception as e: + # If delete fails (e.g., record doesn't exist), that's fine + logger.debug(f"No existing record to delete for {entry_id}: {e}") + + # Append new record + write_deltalake( + str(table_path), + data_with_entry_id, + mode="append", + schema_mode="merge" + ) + + except TableNotFoundError: + # Table doesn't exist, create it + write_deltalake( + str(table_path), + data_with_entry_id, + mode="overwrite" + ) + logger.debug(f"Created new Delta table for {source_key}") + + # Update cache + self._delta_table_cache[source_key] = DeltaTable(str(table_path)) + + logger.debug(f"Added record {entry_id} to {source_key}") + return arrow_data + + def get_record( + self, source_path: tuple[str, ...], entry_id: str + ) -> pa.Table | None: + """ + Get a specific record by entry_id. + + Args: + source_path: Tuple of path components + entry_id: Unique identifier for the record + + Returns: + Arrow table for the record, or None if not found + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + + try: + delta_table = DeltaTable(str(table_path)) + + # Query for the specific entry_id + result = delta_table.to_pyarrow_table( + filter=f"__entry_id = '{entry_id}'" + ) + + if len(result) == 0: + return None + + # Remove the __entry_id column before returning + return self._remove_entry_id_column(result) + + except TableNotFoundError: + return None + except Exception as e: + logger.error(f"Error getting record {entry_id} from {'/'.join(source_path)}: {e}") + return None + + def get_all_records( + self, source_path: tuple[str, ...], add_entry_id_column: bool | str = False + ) -> pa.Table | None: + """ + Retrieve all records for a given source path as a single table. 
+ + Args: + source_path: Tuple of path components + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + + Returns: + Arrow table containing all records, or None if no records found + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + + try: + delta_table = DeltaTable(str(table_path)) + result = delta_table.to_pyarrow_table() + + if len(result) == 0: + return None + + # Handle entry_id column based on parameter + return self._handle_entry_id_column(result, add_entry_id_column) + + except TableNotFoundError: + return None + except Exception as e: + logger.error(f"Error getting all records from {'/'.join(source_path)}: {e}") + return None + + def get_all_records_as_polars( + self, source_path: tuple[str, ...] + ) -> pl.LazyFrame | None: + """ + Retrieve all records for a given source path as a single Polars LazyFrame. + + Args: + source_path: Tuple of path components + + Returns: + Polars LazyFrame containing all records, or None if no records found + """ + all_records = self.get_all_records(source_path) + if all_records is None: + return None + return pl.LazyFrame(all_records) + + def get_records_by_ids( + self, + source_path: tuple[str, ...], + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """ + Retrieve records by entry IDs as a single table. + + Args: + source_path: Tuple of path components + entry_ids: Entry IDs to retrieve + add_entry_id_column: Control entry ID column inclusion + preserve_input_order: If True, return results in input order with nulls for missing + + Returns: + Arrow table containing all found records, or None if no records found + """ + self._validate_source_path(source_path) + + # Convert input to list of strings for consistency + if isinstance(entry_ids, list): + if not entry_ids: + return None + entry_ids_list = entry_ids + elif isinstance(entry_ids, pl.Series): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_list() + elif isinstance(entry_ids, pa.Array): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_pylist() + else: + raise TypeError( + f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" + ) + + table_path = self._get_table_path(source_path) + + try: + delta_table = DeltaTable(str(table_path)) + + # Create filter for the entry IDs - escape single quotes in IDs + escaped_ids = [id_.replace("'", "''") for id_ in entry_ids_list] + id_filter = " OR ".join([f"__entry_id = '{id_}'" for id_ in escaped_ids]) + + result = delta_table.to_pyarrow_table(filter=id_filter) + + if len(result) == 0: + return None + + if preserve_input_order: + # Need to reorder results and add nulls for missing entries + import pandas as pd + + df = result.to_pandas() + df = df.set_index('__entry_id') + + # Create a DataFrame with the desired order, filling missing with NaN + ordered_df = df.reindex(entry_ids_list) + + # Convert back to Arrow + result = pa.Table.from_pandas(ordered_df.reset_index()) + + # Handle entry_id column based on parameter + return self._handle_entry_id_column(result, add_entry_id_column) + + except TableNotFoundError: + return None + except Exception as e: + logger.error(f"Error getting records by IDs from {'/'.join(source_path)}: {e}") + return None + + def get_records_by_ids_as_polars( + 
self, + source_path: tuple[str, ...], + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """ + Retrieve records by entry IDs as a single Polars LazyFrame. + + Args: + source_path: Tuple of path components + entry_ids: Entry IDs to retrieve + add_entry_id_column: Control entry ID column inclusion + preserve_input_order: If True, return results in input order with nulls for missing + + Returns: + Polars LazyFrame containing all found records, or None if no records found + """ + arrow_result = self.get_records_by_ids( + source_path, entry_ids, add_entry_id_column, preserve_input_order + ) + + if arrow_result is None: + return None + + # Convert to Polars LazyFrame + return pl.LazyFrame(arrow_result) + + # Additional utility methods + def list_sources(self) -> list[tuple[str, ...]]: + """ + List all available source paths. + + Returns: + List of source path tuples + """ + sources = [] + + def _scan_directory(current_path: Path, path_components: tuple[str, ...]): + """Recursively scan for Delta tables.""" + for item in current_path.iterdir(): + if not item.is_dir(): + continue + + new_path_components = path_components + (item.name,) + + # Check if this directory contains a Delta table + try: + DeltaTable(str(item)) + sources.append(new_path_components) + except TableNotFoundError: + # Not a Delta table, continue scanning subdirectories + if len(new_path_components) < self.max_hierarchy_depth: + _scan_directory(item, new_path_components) + + _scan_directory(self.base_path, ()) + return sources + + def delete_source(self, source_path: tuple[str, ...]) -> bool: + """ + Delete an entire source (all records for a source path). + + Args: + source_path: Tuple of path components + + Returns: + True if source was deleted, False if it didn't exist + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + source_key = self._get_source_key(source_path) + + if not table_path.exists(): + return False + + try: + # Remove from cache + if source_key in self._delta_table_cache: + del self._delta_table_cache[source_key] + + # Remove directory + import shutil + shutil.rmtree(table_path) + + logger.info(f"Deleted source {source_key}") + return True + + except Exception as e: + logger.error(f"Error deleting source {source_key}: {e}") + return False + + def delete_record(self, source_path: tuple[str, ...], entry_id: str) -> bool: + """ + Delete a specific record. 
+ + Args: + source_path: Tuple of path components + entry_id: ID of the record to delete + + Returns: + True if record was deleted, False if it didn't exist + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + + try: + delta_table = DeltaTable(str(table_path)) + + # Check if record exists + escaped_entry_id = entry_id.replace("'", "''") + existing = delta_table.to_pyarrow_table(filter=f"__entry_id = '{escaped_entry_id}'") + if len(existing) == 0: + return False + + # Delete the record + delta_table.delete(f"__entry_id = '{escaped_entry_id}'") + + # Update cache + source_key = self._get_source_key(source_path) + self._delta_table_cache[source_key] = delta_table + + logger.debug(f"Deleted record {entry_id} from {'/'.join(source_path)}") + return True + + except TableNotFoundError: + return False + except Exception as e: + logger.error(f"Error deleting record {entry_id} from {'/'.join(source_path)}: {e}") + return False + + def get_table_info(self, source_path: tuple[str, ...]) -> dict[str, Any] | None: + """ + Get metadata information about a Delta table. + + Args: + source_path: Tuple of path components + + Returns: + Dictionary with table metadata, or None if table doesn't exist + """ + self._validate_source_path(source_path) + + table_path = self._get_table_path(source_path) + + try: + delta_table = DeltaTable(str(table_path)) + + # Get basic info + schema = delta_table.schema() + history = delta_table.history() + + return { + "path": str(table_path), + "source_path": source_path, + "schema": schema, + "version": delta_table.version(), + "num_files": len(delta_table.files()), + "history_length": len(history), + "latest_commit": history[0] if history else None, + } + + except TableNotFoundError: + return None + except Exception as e: + logger.error(f"Error getting table info for {'/'.join(source_path)}: {e}") + return None \ No newline at end of file diff --git a/src/orcapod/store/dict_data_stores.py b/src/orcapod/stores/dict_data_stores.py similarity index 95% rename from src/orcapod/store/dict_data_stores.py rename to src/orcapod/stores/dict_data_stores.py index c41dd55..edb44e5 100644 --- a/src/orcapod/store/dict_data_stores.py +++ b/src/orcapod/stores/dict_data_stores.py @@ -5,10 +5,10 @@ from pathlib import Path from orcapod.hashing import hash_packet +from orcapod.hashing.types import LegacyPacketHasher from orcapod.hashing.defaults import get_default_composite_file_hasher -from orcapod.hashing.types import PacketHasher -from orcapod.store.types import DataStore -from orcapod.types import Packet +from orcapod.stores.types import DataStore +from orcapod.types import Packet, PacketLike logger = logging.getLogger(__name__) @@ -30,15 +30,15 @@ def memoize( self, function_name: str, function_hash: str, - packet: Packet, - output_packet: Packet, + packet: PacketLike, + output_packet: PacketLike, overwrite: bool = False, - ) -> Packet: + ) -> PacketLike: return output_packet def retrieve_memoized( - self, function_name: str, function_hash: str, packet: Packet - ) -> Packet | None: + self, function_name: str, function_hash: str, packet: PacketLike + ) -> PacketLike | None: return None @@ -46,7 +46,7 @@ class DirDataStore(DataStore): def __init__( self, store_dir: str | PathLike = "./pod_data", - packet_hasher: PacketHasher | None = None, + packet_hasher: LegacyPacketHasher | None = None, copy_files=True, preserve_filename=True, overwrite=False, @@ -71,9 +71,9 @@ def memoize( self, function_name: str, function_hash: str, - packet: Packet, - 
output_packet: Packet, - ) -> Packet: + packet: PacketLike, + output_packet: PacketLike, + ) -> PacketLike: if self.legacy_mode: packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) else: @@ -139,7 +139,7 @@ def memoize( return retrieved_output_packet def retrieve_memoized( - self, function_name: str, function_hash: str, packet: Packet + self, function_name: str, function_hash: str, packet: PacketLike ) -> Packet | None: if self.legacy_mode: packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) diff --git a/src/orcapod/store/file_ops.py b/src/orcapod/stores/file_ops.py similarity index 100% rename from src/orcapod/store/file_ops.py rename to src/orcapod/stores/file_ops.py diff --git a/src/orcapod/store/optimized_memory_store.py b/src/orcapod/stores/optimized_memory_store.py similarity index 100% rename from src/orcapod/store/optimized_memory_store.py rename to src/orcapod/stores/optimized_memory_store.py diff --git a/src/orcapod/store/safe_dir_data_store.py b/src/orcapod/stores/safe_dir_data_store.py similarity index 100% rename from src/orcapod/store/safe_dir_data_store.py rename to src/orcapod/stores/safe_dir_data_store.py diff --git a/src/orcapod/store/transfer_data_store.py b/src/orcapod/stores/transfer_data_store.py similarity index 90% rename from src/orcapod/store/transfer_data_store.py rename to src/orcapod/stores/transfer_data_store.py index c9a4e5d..0c8e215 100644 --- a/src/orcapod/store/transfer_data_store.py +++ b/src/orcapod/stores/transfer_data_store.py @@ -1,7 +1,7 @@ # Implements transfer data store that lets you transfer memoized packets between data stores. -from orcapod.store.types import DataStore -from orcapod.types import Packet +from orcapod.stores.types import DataStore +from orcapod.types import PacketLike class TransferDataStore(DataStore): @@ -14,7 +14,7 @@ def __init__(self, source_store: DataStore, target_store: DataStore) -> None: self.source_store = source_store self.target_store = target_store - def transfer(self, function_name: str, content_hash: str, packet: Packet) -> Packet: + def transfer(self, function_name: str, content_hash: str, packet: PacketLike) -> PacketLike: """ Transfer a memoized packet from the source store to the target store. """ @@ -29,8 +29,8 @@ def transfer(self, function_name: str, content_hash: str, packet: Packet) -> Pac ) def retrieve_memoized( - self, function_name: str, function_hash: str, packet: Packet - ) -> Packet | None: + self, function_name: str, function_hash: str, packet: PacketLike + ) -> PacketLike | None: """ Retrieve a memoized packet from the target store. """ @@ -57,9 +57,9 @@ def memoize( self, function_name: str, function_hash: str, - packet: Packet, - output_packet: Packet, - ) -> Packet: + packet: PacketLike, + output_packet: PacketLike, + ) -> PacketLike: """ Memoize a packet in the target store. """ diff --git a/src/orcapod/store/types.py b/src/orcapod/stores/types.py similarity index 79% rename from src/orcapod/store/types.py rename to src/orcapod/stores/types.py index 49d9a70..c588856 100644 --- a/src/orcapod/store/types.py +++ b/src/orcapod/stores/types.py @@ -1,6 +1,6 @@ from typing import Protocol, runtime_checkable -from orcapod.types import Tag, Packet +from orcapod.types import Tag, PacketLike import pyarrow as pa import polars as pl @@ -21,13 +21,13 @@ def memoize( self, function_name: str, function_hash: str, - packet: Packet, - output_packet: Packet, - ) -> Packet: ... + packet: PacketLike, + output_packet: PacketLike, + ) -> PacketLike: ... 
def retrieve_memoized( - self, function_name: str, function_hash: str, packet: Packet - ) -> Packet | None: ... + self, function_name: str, function_hash: str, packet: PacketLike + ) -> PacketLike | None: ... @runtime_checkable @@ -41,31 +41,29 @@ def __init__(self, *args, **kwargs) -> None: ... def add_record( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_id: str, arrow_data: pa.Table, ignore_duplicate: bool = False, ) -> pa.Table: ... def get_record( - self, source_name: str, source_id: str, entry_id: str + self, source_path: tuple[str,...], entry_id: str ) -> pa.Table | None: ... - def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + def get_all_records(self, source_path: tuple[str,...]) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" ... def get_all_records_as_polars( - self, source_name: str, source_id: str + self, source_path: tuple[str,...] ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a single Polars DataFrame.""" ... def get_records_by_ids( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, preserve_input_order: bool = False, @@ -75,8 +73,7 @@ def get_records_by_ids( def get_records_by_ids_as_polars( self, - source_name: str, - source_id: str, + source_path: tuple[str, ...], entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, preserve_input_order: bool = False, diff --git a/tests/test_hashing/test_basic_composite_hasher.py b/tests/test_hashing/test_basic_composite_hasher.py index 2ef9cf6..f2da406 100644 --- a/tests/test_hashing/test_basic_composite_hasher.py +++ b/tests/test_hashing/test_basic_composite_hasher.py @@ -12,7 +12,7 @@ import pytest -from orcapod.hashing.file_hashers import PathLikeHasherFactory +from orcapod.hashing.file_hashers import LegacyPathLikeHasherFactory def load_hash_lut(): @@ -82,7 +82,7 @@ def verify_path_exists(rel_path): def test_default_file_hasher_file_hash_consistency(): """Test that DefaultFileHasher.hash_file produces consistent results for the sample files.""" hash_lut = load_hash_lut() - hasher = PathLikeHasherFactory.create_basic_composite() + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite() for filename, info in hash_lut.items(): rel_path = info["file"] @@ -104,7 +104,7 @@ def test_default_file_hasher_file_hash_consistency(): def test_default_file_hasher_pathset_hash_consistency(): """Test that DefaultFileHasher.hash_pathset produces consistent results for the sample pathsets.""" hash_lut = load_pathset_hash_lut() - hasher = PathLikeHasherFactory.create_basic_composite() + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite() for name, info in hash_lut.items(): paths_rel = info["paths"] @@ -137,7 +137,7 @@ def test_default_file_hasher_pathset_hash_consistency(): def test_default_file_hasher_packet_hash_consistency(): """Test that DefaultFileHasher.hash_packet produces consistent results for the sample packets.""" hash_lut = load_packet_hash_lut() - hasher = PathLikeHasherFactory.create_basic_composite() + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite() for name, info in hash_lut.items(): structure = info["structure"] @@ -181,7 +181,7 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) + hasher = 
LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -193,7 +193,7 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -222,7 +222,7 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -234,7 +234,7 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -266,7 +266,7 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) @@ -285,7 +285,7 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" diff --git a/tests/test_hashing/test_cached_file_hasher.py b/tests/test_hashing/test_cached_file_hasher.py index 42c9380..d7514a5 100644 --- a/tests/test_hashing/test_cached_file_hasher.py +++ b/tests/test_hashing/test_cached_file_hasher.py @@ -14,7 +14,7 @@ CachedFileHasher, ) from orcapod.hashing.string_cachers import InMemoryCacher -from orcapod.hashing.types import FileHasher, StringCacher +from orcapod.hashing.types import LegacyFileHasher, StringCacher def verify_path_exists(rel_path): @@ -81,7 +81,7 @@ def test_cached_file_hasher_construction(): assert cached_hasher1.string_cacher == string_cacher # Test that CachedFileHasher implements FileHasher protocol - assert isinstance(cached_hasher1, FileHasher) + assert isinstance(cached_hasher1, LegacyFileHasher) def test_cached_file_hasher_file_caching(): @@ -136,7 +136,7 @@ def test_cached_file_hasher_call_counts(): try: # Mock the file_hasher to track calls - mock_file_hasher = MagicMock(spec=FileHasher) + mock_file_hasher = MagicMock(spec=LegacyFileHasher) mock_file_hasher.hash_file.return_value = "mock_file_hash" # Real cacher diff 
--git a/tests/test_hashing/test_hasher_factory.py b/tests/test_hashing/test_hasher_factory.py index afd2392..5776a2d 100644 --- a/tests/test_hashing/test_hasher_factory.py +++ b/tests/test_hashing/test_hasher_factory.py @@ -5,9 +5,9 @@ from pathlib import Path from orcapod.hashing.file_hashers import ( - BasicFileHasher, - CachedFileHasher, - PathLikeHasherFactory, + LegacyDefaultFileHasher, + LegacyCachedFileHasher, + LegacyPathLikeHasherFactory, ) from orcapod.hashing.string_cachers import FileCacher, InMemoryCacher @@ -17,11 +17,11 @@ class TestPathLikeHasherFactoryCreateFileHasher: def test_create_file_hasher_without_cacher(self): """Test creating a file hasher without string cacher (returns BasicFileHasher).""" - hasher = PathLikeHasherFactory.create_file_hasher() + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher() - # Should return BasicFileHasher - assert isinstance(hasher, BasicFileHasher) - assert not isinstance(hasher, CachedFileHasher) + # Should return LegacyDefaultFileHasher + assert isinstance(hasher, LegacyDefaultFileHasher) + assert not isinstance(hasher, LegacyCachedFileHasher) # Check default parameters assert hasher.algorithm == "sha256" @@ -30,60 +30,63 @@ def test_create_file_hasher_without_cacher(self): def test_create_file_hasher_with_cacher(self): """Test creating a file hasher with string cacher (returns CachedFileHasher).""" cacher = InMemoryCacher() - hasher = PathLikeHasherFactory.create_file_hasher(string_cacher=cacher) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(string_cacher=cacher) - # Should return CachedFileHasher - assert isinstance(hasher, CachedFileHasher) + # Should return LegacyCachedFileHasher + assert isinstance(hasher, LegacyCachedFileHasher) assert hasher.string_cacher is cacher - # The underlying file hasher should be BasicFileHasher with defaults - assert isinstance(hasher.file_hasher, BasicFileHasher) + # The underlying file hasher should be LegacyDefaultFileHasher with defaults + assert isinstance(hasher.file_hasher, LegacyDefaultFileHasher) assert hasher.file_hasher.algorithm == "sha256" assert hasher.file_hasher.buffer_size == 65536 def test_create_file_hasher_custom_algorithm(self): """Test creating file hasher with custom algorithm.""" # Without cacher - hasher = PathLikeHasherFactory.create_file_hasher(algorithm="md5") - assert isinstance(hasher, BasicFileHasher) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(algorithm="md5") + assert isinstance(hasher, LegacyDefaultFileHasher) assert hasher.algorithm == "md5" assert hasher.buffer_size == 65536 # With cacher cacher = InMemoryCacher() - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, algorithm="sha512" ) - assert isinstance(hasher, CachedFileHasher) - assert hasher.file_hasher.algorithm == "sha512" - assert hasher.file_hasher.buffer_size == 65536 + assert isinstance(hasher, LegacyCachedFileHasher) + assert isinstance(hasher.file_hasher, LegacyDefaultFileHasher) + assert hasher.file_hasher.algorithm == "sha512" + assert hasher.file_hasher.buffer_size == 65536 def test_create_file_hasher_custom_buffer_size(self): """Test creating file hasher with custom buffer size.""" # Without cacher - hasher = PathLikeHasherFactory.create_file_hasher(buffer_size=32768) - assert isinstance(hasher, BasicFileHasher) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(buffer_size=32768) + assert isinstance(hasher, LegacyDefaultFileHasher) assert 
hasher.algorithm == "sha256" assert hasher.buffer_size == 32768 # With cacher cacher = InMemoryCacher() - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, buffer_size=8192 ) - assert isinstance(hasher, CachedFileHasher) + assert isinstance(hasher, LegacyCachedFileHasher) + assert isinstance(hasher.file_hasher, LegacyDefaultFileHasher) assert hasher.file_hasher.algorithm == "sha256" assert hasher.file_hasher.buffer_size == 8192 def test_create_file_hasher_all_custom_parameters(self): """Test creating file hasher with all custom parameters.""" cacher = InMemoryCacher(max_size=500) - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="blake2b", buffer_size=16384 ) - assert isinstance(hasher, CachedFileHasher) + assert isinstance(hasher, LegacyCachedFileHasher) assert hasher.string_cacher is cacher + assert isinstance(hasher.file_hasher, LegacyDefaultFileHasher) assert hasher.file_hasher.algorithm == "blake2b" assert hasher.file_hasher.buffer_size == 16384 @@ -91,17 +94,17 @@ def test_create_file_hasher_different_cacher_types(self): """Test creating file hasher with different types of string cachers.""" # InMemoryCacher memory_cacher = InMemoryCacher() - hasher1 = PathLikeHasherFactory.create_file_hasher(string_cacher=memory_cacher) - assert isinstance(hasher1, CachedFileHasher) + hasher1 = LegacyPathLikeHasherFactory.create_file_hasher(string_cacher=memory_cacher) + assert isinstance(hasher1, LegacyCachedFileHasher) assert hasher1.string_cacher is memory_cacher # FileCacher with tempfile.NamedTemporaryFile(delete=False) as tmp_file: file_cacher = FileCacher(tmp_file.name) - hasher2 = PathLikeHasherFactory.create_file_hasher( + hasher2 = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=file_cacher ) - assert isinstance(hasher2, CachedFileHasher) + assert isinstance(hasher2, LegacyCachedFileHasher) assert hasher2.string_cacher is file_cacher # Clean up @@ -109,7 +112,7 @@ def test_create_file_hasher_different_cacher_types(self): def test_create_file_hasher_functional_without_cache(self): """Test that created file hasher actually works for hashing files.""" - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( algorithm="sha256", buffer_size=1024 ) @@ -136,7 +139,7 @@ def test_create_file_hasher_functional_without_cache(self): def test_create_file_hasher_functional_with_cache(self): """Test that created cached file hasher works and caches results.""" cacher = InMemoryCacher() - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, algorithm="sha256" ) @@ -164,44 +167,51 @@ def test_create_file_hasher_functional_with_cache(self): def test_create_file_hasher_none_cacher_explicit(self): """Test explicitly passing None for string_cacher.""" - hasher = PathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=None, algorithm="sha1", buffer_size=4096 ) - assert isinstance(hasher, BasicFileHasher) - assert not isinstance(hasher, CachedFileHasher) + assert isinstance(hasher, LegacyDefaultFileHasher) + assert not isinstance(hasher, LegacyCachedFileHasher) assert hasher.algorithm == "sha1" assert hasher.buffer_size == 4096 def test_create_file_hasher_parameter_edge_cases(self): """Test edge cases 
for parameters.""" # Very small buffer size - hasher1 = PathLikeHasherFactory.create_file_hasher(buffer_size=1) + hasher1 = LegacyPathLikeHasherFactory.create_legacy_file_hasher(buffer_size=1) + assert isinstance(hasher1, LegacyDefaultFileHasher) assert hasher1.buffer_size == 1 # Large buffer size - hasher2 = PathLikeHasherFactory.create_file_hasher(buffer_size=1024 * 1024) + hasher2 = LegacyPathLikeHasherFactory.create_legacy_file_hasher(buffer_size=1024 * 1024) + assert isinstance(hasher2, LegacyDefaultFileHasher) assert hasher2.buffer_size == 1024 * 1024 # Different algorithms for algorithm in ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]: - hasher = PathLikeHasherFactory.create_file_hasher(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(algorithm=algorithm) + assert isinstance(hasher, LegacyDefaultFileHasher) assert hasher.algorithm == algorithm def test_create_file_hasher_cache_independence(self): """Test that different cached hashers with same cacher are independent.""" cacher = InMemoryCacher() - hasher1 = PathLikeHasherFactory.create_file_hasher( + hasher1 = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, algorithm="sha256" ) - hasher2 = PathLikeHasherFactory.create_file_hasher( + hasher2 = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, algorithm="md5" ) # Both should use the same cacher but be different instances + assert isinstance(hasher1, LegacyCachedFileHasher) + assert isinstance(hasher2, LegacyCachedFileHasher) assert hasher1.string_cacher is cacher assert hasher2.string_cacher is cacher assert hasher1 is not hasher2 assert hasher1.file_hasher is not hasher2.file_hasher + assert isinstance(hasher1.file_hasher, LegacyDefaultFileHasher) + assert isinstance(hasher2.file_hasher, LegacyDefaultFileHasher) assert hasher1.file_hasher.algorithm != hasher2.file_hasher.algorithm diff --git a/tests/test_hashing/test_hasher_parity.py b/tests/test_hashing/test_hasher_parity.py index 64a6004..a278a92 100644 --- a/tests/test_hashing/test_hasher_parity.py +++ b/tests/test_hashing/test_hasher_parity.py @@ -13,8 +13,8 @@ import pytest -from orcapod.hashing.core import hash_file, hash_packet, hash_pathset -from orcapod.hashing.file_hashers import PathLikeHasherFactory +from orcapod.hashing.legacy_core import hash_file, hash_packet, hash_pathset +from orcapod.hashing.file_hashers import LegacyPathLikeHasherFactory def load_hash_lut(): @@ -73,7 +73,7 @@ def verify_path_exists(rel_path): def test_hasher_core_parity_file_hash(): """Test that BasicFileHasher.hash_file produces the same results as hash_file.""" hash_lut = load_hash_lut() - hasher = PathLikeHasherFactory.create_basic_composite() + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite() # Test all sample files for filename, info in hash_lut.items(): @@ -102,7 +102,7 @@ def test_hasher_core_parity_file_hash(): for buffer_size in buffer_sizes: try: # Create a hasher with specific parameters - hasher = PathLikeHasherFactory.create_basic_composite( + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( algorithm=algorithm, buffer_size=buffer_size ) @@ -147,7 +147,7 @@ def test_hasher_core_parity_pathset_hash(): for buffer_size in buffer_sizes: for char_count in char_counts: # Create a hasher with specific parameters - hasher = PathLikeHasherFactory.create_basic_composite( + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( algorithm=algorithm, buffer_size=buffer_size, 
char_count=char_count, @@ -201,7 +201,7 @@ def test_hasher_core_parity_packet_hash(): for buffer_size in buffer_sizes: for char_count in char_counts: # Create a hasher with specific parameters - hasher = PathLikeHasherFactory.create_basic_composite( + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( algorithm=algorithm, buffer_size=buffer_size, char_count=char_count, diff --git a/tests/test_hashing/test_legacy_composite_hasher.py b/tests/test_hashing/test_legacy_composite_hasher.py new file mode 100644 index 0000000..c9a3ee2 --- /dev/null +++ b/tests/test_hashing/test_legacy_composite_hasher.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +"""Tests for the CompositeFileHasher implementation.""" + +from unittest.mock import patch + +import pytest + +from orcapod.hashing.legacy_core import hash_to_hex +from orcapod.hashing.file_hashers import BasicFileHasher, LegacyDefaultCompositeFileHasher +from orcapod.hashing.types import LegacyFileHasher, LegacyPacketHasher, LegacyPathSetHasher + + +# Custom implementation of hash_file for tests that doesn't check for file existence +def mock_hash_file(file_path, algorithm="sha256", buffer_size=65536) -> str: + """Mock implementation of hash_file that doesn't check for file existence.""" + # Simply return a deterministic hash based on the file path + return hash_to_hex(f"mock_file_hash_{file_path}_{algorithm}") + + +# Custom implementation of hash_pathset for tests that doesn't check for file existence +def mock_hash_pathset( + pathset, algorithm="sha256", buffer_size=65536, char_count=32, file_hasher=None +): + """Mock implementation of hash_pathset that doesn't check for file existence.""" + from collections.abc import Collection + from os import PathLike + from pathlib import Path + + # If file_hasher is None, we'll need to handle it differently + if file_hasher is None: + # Just return a mock hash for testing + if isinstance(pathset, (str, Path, PathLike)): + return f"mock_{pathset}" + return "mock_hash" + + # Handle dictionary case for nested paths + if isinstance(pathset, dict): + hash_dict = {} + for key, value in pathset.items(): + hash_dict[key] = mock_hash_pathset( + value, algorithm, buffer_size, char_count, file_hasher + ) + return hash_to_hex(str(hash_dict)) + + # Handle collection case (list, set, etc.) 
+ if isinstance(pathset, Collection) and not isinstance( + pathset, (str, Path, PathLike) + ): + hash_list = [] + for item in pathset: + hash_list.append( + mock_hash_pathset(item, algorithm, buffer_size, char_count, file_hasher) + ) + return hash_to_hex(str(hash_list)) + + # Handle simple string or Path case + if isinstance(pathset, (str, Path, PathLike)): + if hasattr(file_hasher, "__self__"): # For bound methods + return file_hasher(str(pathset)) + else: + return file_hasher(str(pathset)) + + return "mock_hash" + + +# Custom implementation of hash_packet for tests that doesn't check for file existence +def mock_hash_packet( + packet, + algorithm="sha256", + buffer_size=65536, + char_count=32, + prefix_algorithm=True, + pathset_hasher=None, +): + """Mock implementation of hash_packet that doesn't check for file existence.""" + # Create a simple hash based on the packet structure + hash_value = hash_to_hex(str(packet)) + + # Format it like the real function would + if prefix_algorithm and algorithm: + return ( + f"{algorithm}-{hash_value[: char_count if char_count else len(hash_value)]}" + ) + else: + return hash_value[: char_count if char_count else len(hash_value)] + + +@pytest.fixture(autouse=True) +def patch_hash_functions(): + """Patch the hash functions in the core module for all tests.""" + with ( + patch("orcapod.hashing.core.hash_file", side_effect=mock_hash_file), + patch("orcapod.hashing.core.hash_pathset", side_effect=mock_hash_pathset), + patch("orcapod.hashing.core.hash_packet", side_effect=mock_hash_packet), + ): + yield + + +def test_default_composite_hasher_implements_all_protocols(): + """Test that CompositeFileHasher implements all three protocols.""" + # Create a basic file hasher to be used within the composite hasher + file_hasher = BasicFileHasher() + + # Create the composite hasher + composite_hasher = LegacyDefaultCompositeFileHasher(file_hasher) + + # Verify it implements all three protocols + assert isinstance(composite_hasher, LegacyFileHasher) + assert isinstance(composite_hasher, LegacyPathSetHasher) + assert isinstance(composite_hasher, LegacyPacketHasher) + + +def test_default_composite_hasher_file_hashing(): + """Test CompositeFileHasher's file hashing functionality.""" + # We can use a mock path since our mocks don't require real files + file_path = "/path/to/mock_file.txt" + + # Create a custom mock file hasher + class MockFileHasher: + def hash_file(self, file_path): + return mock_hash_file(file_path) + + file_hasher = MockFileHasher() + composite_hasher = LegacyDefaultCompositeFileHasher(file_hasher) + + # Get hash from the composite hasher and directly from the file hasher + direct_hash = file_hasher.hash_file(file_path) + composite_hash = composite_hasher.hash_file(file_path) + + # The hashes should be identical + assert direct_hash == composite_hash + + +def test_default_composite_hasher_pathset_hashing(): + """Test CompositeFileHasher's path set hashing functionality.""" + + # Create a custom mock file hasher that doesn't check for file existence + class MockFileHasher: + def hash_file(self, file_path) -> str: + return mock_hash_file(file_path) + + file_hasher = MockFileHasher() + composite_hasher = LegacyDefaultCompositeFileHasher(file_hasher) + + # Simple path set with non-existent paths + pathset = ["/path/to/file1.txt", "/path/to/file2.txt"] + + # Hash the pathset + result = composite_hasher.hash_pathset(pathset) + + # The result should be a string hash + assert isinstance(result, str) + + +if __name__ == "__main__": + pytest.main(["-v", 
__file__]) diff --git a/tests/test_hashing/test_packet_hasher.py b/tests/test_hashing/test_packet_hasher.py index 69b89d0..80a16ed 100644 --- a/tests/test_hashing/test_packet_hasher.py +++ b/tests/test_hashing/test_packet_hasher.py @@ -3,11 +3,11 @@ import pytest -from orcapod.hashing.file_hashers import DefaultPacketHasher -from orcapod.hashing.types import PathSetHasher +from orcapod.hashing.file_hashers import LegacyDefaultPacketHasher +from orcapod.hashing.types import LegacyPathSetHasher -class MockPathSetHasher(PathSetHasher): +class MockPathSetHasher(LegacyPathSetHasher): """Simple mock PathSetHasher for testing.""" def __init__(self, hash_value="mock_hash"): @@ -19,10 +19,10 @@ def hash_pathset(self, pathset): return f"{self.hash_value}_{pathset}" -def test_default_packet_hasher_empty_packet(): - """Test DefaultPacketHasher with an empty packet.""" +def test_legacy_packet_hasher_empty_packet(): + """Test LegacyPacketHasher with an empty packet.""" pathset_hasher = MockPathSetHasher() - packet_hasher = DefaultPacketHasher(pathset_hasher) + packet_hasher = LegacyDefaultPacketHasher(pathset_hasher) # Test with empty packet packet = {} @@ -36,10 +36,10 @@ def test_default_packet_hasher_empty_packet(): assert isinstance(result, str) -def test_default_packet_hasher_single_entry(): - """Test DefaultPacketHasher with a packet containing a single entry.""" +def test_legacy_packet_hasher_single_entry(): + """Test LegacyPacketHasher with a packet containing a single entry.""" pathset_hasher = MockPathSetHasher() - packet_hasher = DefaultPacketHasher(pathset_hasher) + packet_hasher = LegacyDefaultPacketHasher(pathset_hasher) # Test with a single entry packet = {"input": "/path/to/file.txt"} @@ -54,10 +54,10 @@ def test_default_packet_hasher_single_entry(): assert isinstance(result, str) -def test_default_packet_hasher_multiple_entries(): - """Test DefaultPacketHasher with a packet containing multiple entries.""" +def test_legacy_packet_hasher_multiple_entries(): + """Test LegacyPacketHasher with a packet containing multiple entries.""" pathset_hasher = MockPathSetHasher() - packet_hasher = DefaultPacketHasher(pathset_hasher) + packet_hasher = LegacyDefaultPacketHasher(pathset_hasher) # Test with multiple entries packet = { @@ -78,10 +78,10 @@ def test_default_packet_hasher_multiple_entries(): assert isinstance(result, str) -def test_default_packet_hasher_nested_structure(): - """Test DefaultPacketHasher with a deeply nested packet structure.""" +def test_legacy_packet_hasher_nested_structure(): + """Test LegacyPacketHasher with a deeply nested packet structure.""" pathset_hasher = MockPathSetHasher() - packet_hasher = DefaultPacketHasher(pathset_hasher) + packet_hasher = LegacyDefaultPacketHasher(pathset_hasher) # Test with nested packet structure packet = { @@ -103,16 +103,16 @@ def test_default_packet_hasher_nested_structure(): assert isinstance(result, str) -def test_default_packet_hasher_with_char_count(): - """Test DefaultPacketHasher with different char_count values.""" +def test_legacy_packet_hasher_with_char_count(): + """Test LegacyPacketHasher with different char_count values.""" pathset_hasher = MockPathSetHasher() # Test with default char_count (32) - default_hasher = DefaultPacketHasher(pathset_hasher) + default_hasher = LegacyDefaultPacketHasher(pathset_hasher) default_result = default_hasher.hash_packet({"input": "/path/to/file.txt"}) # Test with custom char_count - custom_hasher = DefaultPacketHasher(pathset_hasher, char_count=16) + custom_hasher = 
LegacyDefaultPacketHasher(pathset_hasher, char_count=16) custom_result = custom_hasher.hash_packet({"input": "/path/to/file.txt"}) # Results should be different based on char_count diff --git a/tests/test_hashing/test_path_set_hasher.py b/tests/test_hashing/test_path_set_hasher.py index 65e626a..9286f82 100644 --- a/tests/test_hashing/test_path_set_hasher.py +++ b/tests/test_hashing/test_path_set_hasher.py @@ -9,11 +9,11 @@ import pytest import orcapod.hashing.legacy_core -from orcapod.hashing.file_hashers import DefaultPathsetHasher -from orcapod.hashing.types import FileHasher +from orcapod.hashing.file_hashers import LegacyDefaultPathsetHasher +from orcapod.hashing.types import LegacyFileHasher -class MockFileHasher(FileHasher): +class MockFileHasher(LegacyFileHasher): """Simple mock FileHasher for testing.""" def __init__(self, hash_value="mock_hash"): @@ -90,10 +90,10 @@ def patch_hash_pathset(): yield -def test_default_pathset_hasher_single_file(): - """Test DefaultPathsetHasher with a single file path.""" +def test_legacy_pathset_hasher_single_file(): + """Test LegacyPathsetHasher with a single file path.""" file_hasher = MockFileHasher() - pathset_hasher = DefaultPathsetHasher(file_hasher) + pathset_hasher = LegacyDefaultPathsetHasher(file_hasher) # Create a real file for testing file_path = create_temp_file() @@ -116,7 +116,7 @@ def test_default_pathset_hasher_single_file(): def test_default_pathset_hasher_multiple_files(): """Test DefaultPathsetHasher with multiple files in a list.""" file_hasher = MockFileHasher() - pathset_hasher = DefaultPathsetHasher(file_hasher) + pathset_hasher = LegacyDefaultPathsetHasher(file_hasher) # Create real files for testing file_paths = [create_temp_file(f"content {i}") for i in range(3)] @@ -195,7 +195,7 @@ def test_default_pathset_hasher_nested_paths(): def test_default_pathset_hasher_with_nonexistent_files(): """Test DefaultPathsetHasher with both existent and non-existent files.""" file_hasher = MockFileHasher() - pathset_hasher = DefaultPathsetHasher(file_hasher) + pathset_hasher = LegacyDefaultPathsetHasher(file_hasher) # Reset the file_hasher's call list file_hasher.file_hash_calls = [] @@ -249,14 +249,14 @@ def test_default_pathset_hasher_with_char_count(): try: # Test with default char_count (32) - default_hasher = DefaultPathsetHasher(file_hasher) + default_hasher = LegacyDefaultPathsetHasher(file_hasher) default_result = default_hasher.hash_pathset(file_path) # Reset call list file_hasher.file_hash_calls = [] # Test with custom char_count - custom_hasher = DefaultPathsetHasher(file_hasher, char_count=16) + custom_hasher = LegacyDefaultPathsetHasher(file_hasher, char_count=16) custom_result = custom_hasher.hash_pathset(file_path) # Both should have called the file_hasher once diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index 32d8618..d6cc106 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -8,15 +8,15 @@ import pytest from orcapod.hashing.types import ( - CompositeFileHasher, - FileHasher, - PacketHasher, - PathSetHasher, + LegacyCompositeFileHasher, + LegacyFileHasher, + LegacyPacketHasher, + LegacyPathSetHasher, ) -from orcapod.store.dict_data_stores import DirDataStore +from orcapod.stores.dict_data_stores import DirDataStore -class MockFileHasher(FileHasher): +class MockFileHasher(LegacyFileHasher): """Mock FileHasher for testing.""" def __init__(self, hash_value="mock_hash"): @@ -28,19 +28,19 @@ def hash_file(self, file_path): 
return f"{self.hash_value}_file" -class MockPathSetHasher(PathSetHasher): +class MockPathSetHasher(LegacyPathSetHasher): """Mock PathSetHasher for testing.""" def __init__(self, hash_value="mock_hash"): self.hash_value = hash_value self.pathset_hash_calls = [] - def hash_pathset(self, pathset): + def hash_pathset(self, pathset) -> str: self.pathset_hash_calls.append(pathset) return f"{self.hash_value}_pathset" -class MockPacketHasher(PacketHasher): +class MockPacketHasher(LegacyPacketHasher): """Mock PacketHasher for testing.""" def __init__(self, hash_value="mock_hash"): @@ -52,7 +52,7 @@ def hash_packet(self, packet): return f"{self.hash_value}_packet" -class MockCompositeHasher(CompositeFileHasher): +class MockCompositeHasher(LegacyCompositeFileHasher): """Mock CompositeHasher that implements all three hash protocols.""" def __init__(self, hash_value="mock_hash"): @@ -61,15 +61,15 @@ def __init__(self, hash_value="mock_hash"): self.pathset_hash_calls = [] self.packet_hash_calls = [] - def hash_file(self, file_path): + def hash_file_content(self, file_path): self.file_hash_calls.append(file_path) return f"{self.hash_value}_file" - def hash_pathset(self, pathset): + def hash_pathset(self, pathset) -> str: self.pathset_hash_calls.append(pathset) return f"{self.hash_value}_pathset" - def hash_packet(self, packet): + def hash_packet(self, packet) -> str: self.packet_hash_calls.append(packet) return f"{self.hash_value}_packet" @@ -86,7 +86,7 @@ def test_dir_data_store_init_default_hasher(temp_dir): assert store_dir.is_dir() # Verify the default PacketHasher is used - assert isinstance(store.packet_hasher, PacketHasher) + assert isinstance(store.packet_hasher, LegacyPacketHasher) # Check default parameters assert store.copy_files is True diff --git a/tests/test_store/test_integration.py b/tests/test_store/test_integration.py index 48e0703..00d3b99 100644 --- a/tests/test_store/test_integration.py +++ b/tests/test_store/test_integration.py @@ -9,10 +9,10 @@ from orcapod.hashing.file_hashers import ( BasicFileHasher, CachedFileHasher, - DefaultCompositeFileHasher, + LegacyCompositeFileHasher, ) from orcapod.hashing.string_cachers import InMemoryCacher -from orcapod.store.dict_data_stores import DirDataStore, NoOpDataStore +from orcapod.stores.dict_data_stores import DirDataStore, NoOpDataStore def test_integration_with_cached_file_hasher(temp_dir, sample_files): @@ -28,7 +28,7 @@ def test_integration_with_cached_file_hasher(temp_dir, sample_files): ) # Create a CompositeFileHasher that will use the CachedFileHasher - composite_hasher = DefaultCompositeFileHasher(file_hasher) + composite_hasher = LegacyCompositeFileHasher(file_hasher) # Create the store with CompositeFileHasher store = DirDataStore(store_dir=store_dir, packet_hasher=composite_hasher) diff --git a/tests/test_store/test_noop_data_store.py b/tests/test_store/test_noop_data_store.py index ab0eecd..4ff838f 100644 --- a/tests/test_store/test_noop_data_store.py +++ b/tests/test_store/test_noop_data_store.py @@ -3,7 +3,7 @@ import pytest -from orcapod.store.dict_data_stores import NoOpDataStore +from orcapod.stores.dict_data_stores import NoOpDataStore def test_noop_data_store_memoize(): @@ -43,7 +43,7 @@ def test_noop_data_store_retrieve_memoized(): def test_noop_data_store_is_data_store_subclass(): """Test that NoOpDataStore is a subclass of DataStore.""" - from orcapod.store import DataStore + from orcapod.stores import DataStore store = NoOpDataStore() assert isinstance(store, DataStore) diff --git 
a/tests/test_store/test_transfer_data_store.py b/tests/test_store/test_transfer_data_store.py index 191da89..21ed4c9 100644 --- a/tests/test_store/test_transfer_data_store.py +++ b/tests/test_store/test_transfer_data_store.py @@ -5,12 +5,12 @@ import pytest -from orcapod.hashing.types import PacketHasher -from orcapod.store.dict_data_stores import DirDataStore, NoOpDataStore -from orcapod.store.transfer_data_store import TransferDataStore +from orcapod.hashing.types import LegacyPacketHasher +from orcapod.stores.dict_data_stores import DirDataStore, NoOpDataStore +from orcapod.stores.transfer_data_store import TransferDataStore -class MockPacketHasher(PacketHasher): +class MockPacketHasher(LegacyPacketHasher): """Mock PacketHasher for testing.""" def __init__(self, hash_value="mock_hash"): From 41f1b63061247d4a88e4063e4a72c0b906207fd3 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 1 Jul 2025 03:58:57 +0000 Subject: [PATCH 30/57] test: fix legacy tests --- tests/test_hashing/test_legacy_composite_hasher.py | 10 +++++----- tests/test_store/test_dir_data_store.py | 2 +- tests/test_store/test_integration.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_hashing/test_legacy_composite_hasher.py b/tests/test_hashing/test_legacy_composite_hasher.py index c9a3ee2..f3a8de4 100644 --- a/tests/test_hashing/test_legacy_composite_hasher.py +++ b/tests/test_hashing/test_legacy_composite_hasher.py @@ -6,7 +6,7 @@ import pytest from orcapod.hashing.legacy_core import hash_to_hex -from orcapod.hashing.file_hashers import BasicFileHasher, LegacyDefaultCompositeFileHasher +from orcapod.hashing.file_hashers import LegacyDefaultFileHasher, LegacyDefaultCompositeFileHasher from orcapod.hashing.types import LegacyFileHasher, LegacyPacketHasher, LegacyPathSetHasher @@ -89,9 +89,9 @@ def mock_hash_packet( def patch_hash_functions(): """Patch the hash functions in the core module for all tests.""" with ( - patch("orcapod.hashing.core.hash_file", side_effect=mock_hash_file), - patch("orcapod.hashing.core.hash_pathset", side_effect=mock_hash_pathset), - patch("orcapod.hashing.core.hash_packet", side_effect=mock_hash_packet), + patch("orcapod.hashing.legacy_core.hash_file", side_effect=mock_hash_file), + patch("orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset), + patch("orcapod.hashing.legacy_core.hash_packet", side_effect=mock_hash_packet), ): yield @@ -99,7 +99,7 @@ def patch_hash_functions(): def test_default_composite_hasher_implements_all_protocols(): """Test that CompositeFileHasher implements all three protocols.""" # Create a basic file hasher to be used within the composite hasher - file_hasher = BasicFileHasher() + file_hasher = LegacyDefaultFileHasher() # Create the composite hasher composite_hasher = LegacyDefaultCompositeFileHasher(file_hasher) diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index d6cc106..eae39eb 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -461,7 +461,7 @@ def test_dir_data_store_with_default_packet_hasher(temp_dir, sample_files): store = DirDataStore(store_dir=store_dir) # Verify that default PacketHasher was created - assert isinstance(store.packet_hasher, PacketHasher) + assert isinstance(store.packet_hasher, LegacyPacketHasher) # Test memoization and retrieval packet = {"input_file": sample_files["input"]["file1"]} diff --git a/tests/test_store/test_integration.py b/tests/test_store/test_integration.py index 
00d3b99..2a6e253 100644 --- a/tests/test_store/test_integration.py +++ b/tests/test_store/test_integration.py @@ -9,7 +9,7 @@ from orcapod.hashing.file_hashers import ( BasicFileHasher, CachedFileHasher, - LegacyCompositeFileHasher, + LegacyDefaultCompositeFileHasher, ) from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.stores.dict_data_stores import DirDataStore, NoOpDataStore @@ -28,7 +28,7 @@ def test_integration_with_cached_file_hasher(temp_dir, sample_files): ) # Create a CompositeFileHasher that will use the CachedFileHasher - composite_hasher = LegacyCompositeFileHasher(file_hasher) + composite_hasher = LegacyDefaultCompositeFileHasher(file_hasher) # Create the store with CompositeFileHasher store = DirDataStore(store_dir=store_dir, packet_hasher=composite_hasher) From fe423f7abb429d6e5fd78609a2c9fa77457542a2 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 1 Jul 2025 05:39:30 +0000 Subject: [PATCH 31/57] fix: make all tests functional --- tests/test_hashing/test_cached_file_hasher.py | 24 +++++++++---------- tests/test_hashing/test_hasher_factory.py | 4 ++-- tests/test_hashing/test_path_set_hasher.py | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_hashing/test_cached_file_hasher.py b/tests/test_hashing/test_cached_file_hasher.py index d7514a5..8b9ce30 100644 --- a/tests/test_hashing/test_cached_file_hasher.py +++ b/tests/test_hashing/test_cached_file_hasher.py @@ -10,8 +10,8 @@ import pytest from orcapod.hashing.file_hashers import ( - BasicFileHasher, - CachedFileHasher, + LegacyDefaultFileHasher, + LegacyCachedFileHasher, ) from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.hashing.types import LegacyFileHasher, StringCacher @@ -73,10 +73,10 @@ def load_packet_hash_lut(): def test_cached_file_hasher_construction(): """Test that CachedFileHasher can be constructed with various parameters.""" # Test with default parameters - file_hasher = BasicFileHasher() + file_hasher = LegacyDefaultFileHasher() string_cacher = InMemoryCacher() - cached_hasher1 = CachedFileHasher(file_hasher, string_cacher) + cached_hasher1 = LegacyCachedFileHasher(file_hasher, string_cacher) assert cached_hasher1.file_hasher == file_hasher assert cached_hasher1.string_cacher == string_cacher @@ -99,8 +99,8 @@ def test_cached_file_hasher_file_caching(): mock_string_cacher = MagicMock(spec=StringCacher) mock_string_cacher.get_cached.return_value = None # Initially no cached value - file_hasher = BasicFileHasher() - cached_hasher = CachedFileHasher(file_hasher, mock_string_cacher) + file_hasher = LegacyDefaultFileHasher() + cached_hasher = LegacyCachedFileHasher(file_hasher, mock_string_cacher) # First call should compute the hash and cache it result1 = cached_hasher.hash_file(file_path) @@ -143,7 +143,7 @@ def test_cached_file_hasher_call_counts(): string_cacher = InMemoryCacher() # Create the cached file hasher with all caching enabled - cached_hasher = CachedFileHasher( + cached_hasher = LegacyCachedFileHasher( mock_file_hasher, string_cacher, ) @@ -181,11 +181,11 @@ def test_cached_file_hasher_performance(): file_path = verify_path_exists(info["file"]) # Setup non-cached hasher - file_hasher = BasicFileHasher() + file_hasher = LegacyDefaultFileHasher() # Setup cached hasher string_cacher = InMemoryCacher() - cached_hasher = CachedFileHasher(file_hasher, string_cacher) + cached_hasher = LegacyCachedFileHasher(file_hasher, string_cacher) # Measure time for multiple hash operations with non-cached hasher start_time = time.time() 
@@ -221,11 +221,11 @@ def test_cached_file_hasher_with_different_cachers(): try: file_path = temp_file.name - file_hasher = BasicFileHasher() + file_hasher = LegacyDefaultFileHasher() # Test with InMemoryCacher mem_cacher = InMemoryCacher(max_size=10) - cached_hasher1 = CachedFileHasher(file_hasher, mem_cacher) + cached_hasher1 = LegacyCachedFileHasher(file_hasher, mem_cacher) # First hash call hash1 = cached_hasher1.hash_file(file_path) @@ -249,7 +249,7 @@ def clear_cache(self) -> None: self.storage.clear() custom_cacher = CustomCacher() - cached_hasher2 = CachedFileHasher(file_hasher, custom_cacher) + cached_hasher2 = LegacyCachedFileHasher(file_hasher, custom_cacher) # Get hash with custom cacher hash2 = cached_hasher2.hash_file(file_path) diff --git a/tests/test_hashing/test_hasher_factory.py b/tests/test_hashing/test_hasher_factory.py index 5776a2d..69804a3 100644 --- a/tests/test_hashing/test_hasher_factory.py +++ b/tests/test_hashing/test_hasher_factory.py @@ -80,7 +80,7 @@ def test_create_file_hasher_custom_buffer_size(self): def test_create_file_hasher_all_custom_parameters(self): """Test creating file hasher with all custom parameters.""" cacher = InMemoryCacher(max_size=500) - hasher = LegacyPathLikeHasherFactory.create_file_hasher( + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( string_cacher=cacher, algorithm="blake2b", buffer_size=16384 ) @@ -94,7 +94,7 @@ def test_create_file_hasher_different_cacher_types(self): """Test creating file hasher with different types of string cachers.""" # InMemoryCacher memory_cacher = InMemoryCacher() - hasher1 = LegacyPathLikeHasherFactory.create_file_hasher(string_cacher=memory_cacher) + hasher1 = LegacyPathLikeHasherFactory.create_legacy_file_hasher(string_cacher=memory_cacher) assert isinstance(hasher1, LegacyCachedFileHasher) assert hasher1.string_cacher is memory_cacher diff --git a/tests/test_hashing/test_path_set_hasher.py b/tests/test_hashing/test_path_set_hasher.py index 9286f82..0a48acb 100644 --- a/tests/test_hashing/test_path_set_hasher.py +++ b/tests/test_hashing/test_path_set_hasher.py @@ -86,7 +86,7 @@ def mock_hash_pathset( @pytest.fixture(autouse=True) def patch_hash_pathset(): """Patch the hash_pathset function in the hashing module for all tests.""" - with patch("orcapod.hashing.core.hash_pathset", side_effect=mock_hash_pathset): + with patch("orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset): yield @@ -225,7 +225,7 @@ def custom_hash_nonexistent(pathset, **kwargs): # Patch hash_pathset just for this test with patch( - "orcapod.hashing.core.hash_pathset", side_effect=custom_hash_nonexistent + "orcapod.hashing.legacy_core.hash_pathset", side_effect=custom_hash_nonexistent ): result = pathset_hasher.hash_pathset(pathset) From ba1f45d6096a504096b97c5d5673cb2e851133d9 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 06:53:23 +0000 Subject: [PATCH 32/57] refactor: cleanup imports and use versioned object hasher --- src/orcapod/core/base.py | 4 -- src/orcapod/core/operators.py | 1 - src/orcapod/core/pod.py | 2 +- src/orcapod/core/sources.py | 2 +- src/orcapod/hashing/__init__.py | 13 ----- src/orcapod/hashing/arrow_hashers.py | 5 +- src/orcapod/hashing/defaults.py | 37 ++++++------- src/orcapod/hashing/object_hashers.py | 7 ++- src/orcapod/hashing/types.py | 16 ++++-- src/orcapod/hashing/versioned_hashers.py | 53 ++++++++++++++++--- .../pipeline/{wrappers.py => nodes.py} | 22 ++++---- src/orcapod/pipeline/pipeline.py | 3 +- src/orcapod/stores/dict_data_stores.py | 2 +- src/orcapod/utils/object_spec.py | 19 +++++++ 14 files changed, 120 insertions(+), 66 deletions(-) rename src/orcapod/pipeline/{wrappers.py => nodes.py} (97%) create mode 100644 src/orcapod/utils/object_spec.py diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 7c9a299..f0d5362 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -4,10 +4,6 @@ from collections.abc import Callable, Collection, Iterator from typing import Any - -from orcapod.hashing import HashableMixin, ObjectHasher -from orcapod.hashing import get_default_object_hasher - from orcapod.hashing import ContentIdentifiableBase from orcapod.types import Packet, Tag, TypeSpec from orcapod.types.typespec_utils import get_typespec_from_dict diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index c26dc2d..598f2e3 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -5,7 +5,6 @@ from orcapod.types import Packet, Tag, TypeSpec from orcapod.types.typespec_utils import union_typespecs, intersection_typespecs -from orcapod.hashing import function_content_hash, hash_function from orcapod.core.base import Kernel, SyncStream, Operator from orcapod.core.streams import SyncStreamFromGenerator from orcapod.utils.stream_utils import ( diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index 4271887..ae6778d 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -13,8 +13,8 @@ from orcapod.hashing import ( FunctionInfoExtractor, - get_function_signature, ) +from orcapod.hashing.legacy_core import get_function_signature from orcapod.core import Kernel from orcapod.core.operators import Join from orcapod.core.streams import ( diff --git a/src/orcapod/core/sources.py b/src/orcapod/core/sources.py index 21adae9..3d79e7a 100644 --- a/src/orcapod/core/sources.py +++ b/src/orcapod/core/sources.py @@ -4,7 +4,7 @@ from typing import Any, Literal from orcapod.core.base import Source -from orcapod.hashing import hash_function +from orcapod.hashing.legacy_core import hash_function from orcapod.core.streams import SyncStream, SyncStreamFromGenerator from orcapod.types import Packet, Tag diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index 7aaf11b..b1e5849 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -1,17 +1,4 @@ -from .legacy_core import ( - HashableMixin, - function_content_hash, - get_function_signature, - hash_file, - hash_function, - hash_packet, - hash_pathset, - hash_to_hex, - hash_to_int, - hash_to_uuid, -) from .defaults import ( - get_default_composite_file_hasher, get_default_object_hasher, get_default_arrow_hasher, ) diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index c50ebfc..3545911 100644 --- a/src/orcapod/hashing/arrow_hashers.py 
+++ b/src/orcapod/hashing/arrow_hashers.py @@ -73,7 +73,6 @@ def set_cacher(self, semantic_type: str, cacher: StringCacher) -> None: This is a no-op for SemanticArrowHasher since it hashes column contents directly. """ - # SemanticArrowHasher does not use string caching, so this is a no-op if semantic_type in self.semantic_type_hashers: self.semantic_type_hashers[semantic_type].set_cacher(cacher) else: @@ -179,7 +178,7 @@ def _serialize_table_ipc(self, table: pa.Table) -> bytes: return buffer.getvalue() - def hash_table(self, table: pa.Table, add_prefix: bool = True) -> str: + def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: """ Compute stable hash of Arrow table. @@ -208,7 +207,7 @@ def hash_table(self, table: pa.Table, add_prefix: bool = True) -> str: hasher.update(serialized_bytes) hash_str = hasher.hexdigest() - if add_prefix: + if prefix_hasher_id: hash_str = f"{self.get_hasher_id()}:{hash_str}" return hash_str diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 8ba7c0b..f9dee37 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -9,11 +9,11 @@ from orcapod.hashing.file_hashers import BasicFileHasher, LegacyPathLikeHasherFactory from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.hashing.object_hashers import ObjectHasher -from orcapod.hashing.object_hashers import DefaultObjectHasher, LegacyObjectHasher +from orcapod.hashing.object_hashers import LegacyObjectHasher from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory from orcapod.hashing.arrow_hashers import SemanticArrowHasher from orcapod.hashing.semantic_type_hashers import PathHasher -from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher +from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher, get_versioned_object_hasher def get_default_arrow_hasher( @@ -36,6 +36,21 @@ def get_default_arrow_hasher( return arrow_hasher +def get_default_object_hasher() -> ObjectHasher: + object_hasher = get_versioned_object_hasher() + return object_hasher + + + +def get_legacy_object_hasher() -> ObjectHasher: + function_info_extractor = ( + FunctionInfoExtractorFactory.create_function_info_extractor( + strategy="signature" + ) + ) + return LegacyObjectHasher(function_info_extractor=function_info_extractor) + + def get_default_composite_file_hasher(with_cache=True) -> LegacyCompositeFileHasher: if with_cache: # use unlimited caching @@ -48,21 +63,3 @@ def get_default_composite_file_hasher_with_cacher(cacher=None) -> LegacyComposit if cacher is None: cacher = InMemoryCacher(max_size=None) return LegacyPathLikeHasherFactory.create_cached_legacy_composite(cacher) - - -def get_default_object_hasher() -> ObjectHasher: - function_info_extractor = ( - FunctionInfoExtractorFactory.create_function_info_extractor( - strategy="signature" - ) - ) - return DefaultObjectHasher(function_info_extractor=function_info_extractor) - - -def get_legacy_object_hasher() -> ObjectHasher: - function_info_extractor = ( - FunctionInfoExtractorFactory.create_function_info_extractor( - strategy="signature" - ) - ) - return LegacyObjectHasher(function_info_extractor=function_info_extractor) diff --git a/src/orcapod/hashing/object_hashers.py b/src/orcapod/hashing/object_hashers.py index 7e35ccb..bdd0169 100644 --- a/src/orcapod/hashing/object_hashers.py +++ b/src/orcapod/hashing/object_hashers.py @@ -4,17 +4,22 @@ from .hash_utils import hash_object -class 
DefaultObjectHasher(ObjectHasher): +class BasicObjectHasher(ObjectHasher): """ Default object hasher used throughout the codebase. """ def __init__( self, + hasher_id: str, function_info_extractor: FunctionInfoExtractor | None = None, ): + self._hasher_id = hasher_id self.function_info_extractor = function_info_extractor + def get_hasher_id(self) -> str: + return self._hasher_id + def hash(self, obj: object) -> bytes: """ Hash an object to a byte representation. diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index c7d79da..24afdbc 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -43,7 +43,13 @@ def hash(self, obj: Any) -> bytes: """ ... - def hash_to_hex(self, obj: Any, char_count: int | None = None) -> str: + @abstractmethod + def get_hasher_id(self) -> str: + """ + Returns a unique identifier/name assigned to the hasher + """ + + def hash_to_hex(self, obj: Any, char_count: int | None = None, prefix_hasher_id:bool=False) -> str: hash_bytes = self.hash(obj) hex_str = hash_bytes.hex() @@ -53,7 +59,9 @@ def hash_to_hex(self, obj: Any, char_count: int | None = None) -> str: raise ValueError( f"Cannot truncate to {char_count} chars, hash only has {len(hex_str)}" ) - return hex_str[:char_count] + hex_str = hex_str[:char_count] + if prefix_hasher_id: + hex_str = self.get_hasher_id() + ":" + hex_str return hex_str def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: @@ -74,7 +82,6 @@ def hash_to_uuid( self, obj: Any, namespace: uuid.UUID = uuid.NAMESPACE_OID ) -> uuid.UUID: """Convert hash to proper UUID5.""" - # Use the hex representation as input to UUID5 return uuid.uuid5(namespace, self.hash(obj)) @@ -88,8 +95,9 @@ def hash_file(self, file_path: PathLike) -> bytes: ... @runtime_checkable class ArrowHasher(Protocol): """Protocol for hashing arrow packets.""" + def get_hasher_id(self) -> str: ... - def hash_table(self, table: pa.Table, add_prefix: bool = True) -> str: ... + def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: ... @runtime_checkable diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py index 22c715e..18c8680 100644 --- a/src/orcapod/hashing/versioned_hashers.py +++ b/src/orcapod/hashing/versioned_hashers.py @@ -1,15 +1,16 @@ # A collection of versioned hashers that provide a "default" implementation of hashers. from .arrow_hashers import SemanticArrowHasher +from .types import ObjectHasher, ArrowHasher import importlib from typing import Any CURRENT_VERSION = "v0.1" -versioned_hashers = { +versioned_semantic_arrow_hashers = { "v0.1": { "_class": "orcapod.hashing.arrow_hashers.SemanticArrowHasher", "config": { - "hasher_id": "default_v0.1", + "hasher_id": "arrow_v0.1", "hash_algorithm": "sha256", "chunk_size": 8192, "semantic_type_hashers": { @@ -29,6 +30,23 @@ } } +versioned_object_hashers = { + "v0.1": { + "_class": "orcapod.hashing.object_hashers.BasicObjectHasher", + "config": { + "hasher_id": "object_v0.1", + "function_info_extractor" : { + "_class": "orcapod.hashing.function_info_extractors.FunctionSignatureExtractor", + "config": { + "include_module": True, + "include_defaults": True + } + + } + } + } +} + def parse_objectspec(obj_spec: dict) -> Any: if "_class" in obj_spec: @@ -51,7 +69,7 @@ def parse_objectspec(obj_spec: dict) -> Any: def get_versioned_semantic_arrow_hasher( version: str | None = None, -) -> SemanticArrowHasher: +) -> ArrowHasher: """ Get the versioned hasher for the specified version. 
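# Illustrative sketch, not part of the committed patch: the versioned registries above
# map a version string to a nested "_class"/"config" spec, and parse_objectspec()
# resolves such a spec into a live object by importing the named class and
# instantiating it with its (recursively parsed) config. The mini resolver and the
# datetime.timedelta target below are stand-ins chosen so the sketch runs without
# orcapod installed; they mirror, rather than reproduce, the patch's helper.
import importlib

def _resolve(spec: dict):
    # "_class" names the fully qualified class, "config" supplies its keyword arguments
    module_name, class_name = spec["_class"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = {
        k: _resolve(v) if isinstance(v, dict) and "_class" in v else v
        for k, v in spec.get("config", {}).items()
    }
    return cls(**kwargs)

delta = _resolve({"_class": "datetime.timedelta", "config": {"hours": 1, "minutes": 30}})
assert delta.total_seconds() == 5400.0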
@@ -59,13 +77,36 @@ def get_versioned_semantic_arrow_hasher( version (str): The version of the hasher to retrieve. Returns: - SemanticArrowHasher: An instance of the hasher for the specified version. + ArrowHasher: An instance of the arrow hasher of the specified version. + """ + if version is None: + version = CURRENT_VERSION + + if version not in versioned_semantic_arrow_hashers: + raise ValueError(f"Unsupported hasher version: {version}") + + hasher_spec = versioned_semantic_arrow_hashers[version] + return parse_objectspec(hasher_spec) + + +def get_versioned_object_hasher( + version: str | None = None, +) -> ObjectHasher: + """ + Get an object hasher for the specified version. + + Args: + version (str): The version of the hasher to retrieve. + + Returns: + Object: An instance of the object hasher of the specified version. """ if version is None: version = CURRENT_VERSION - if version not in versioned_hashers: + if version not in versioned_object_hashers: raise ValueError(f"Unsupported hasher version: {version}") - hasher_spec = versioned_hashers[version] + hasher_spec = versioned_object_hashers[version] return parse_objectspec(hasher_spec) + diff --git a/src/orcapod/pipeline/wrappers.py b/src/orcapod/pipeline/nodes.py similarity index 97% rename from src/orcapod/pipeline/wrappers.py rename to src/orcapod/pipeline/nodes.py index 720609e..2ac3763 100644 --- a/src/orcapod/pipeline/wrappers.py +++ b/src/orcapod/pipeline/nodes.py @@ -149,7 +149,6 @@ def post_call(self, tag: Tag, packet: Packet) -> None: ... def output_iterator_completion_hook(self) -> None: ... - class CachedKernelWrapper(KernelInvocationWrapper, Source): """ A Kernel wrapper that wraps a kernel and stores the outputs of the kernel. @@ -216,7 +215,7 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): self.update_cached_values() def update_cached_values(self): - self.source_info = self.store_path_prefix + (self.label, self.kernel_hasher.hash_to_hex(self.kernel)) + self.source_info = self.store_path_prefix + (self.label, self.kernel_hasher.hash_to_hex(self.kernel, prefix_hasher_id=True)) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.packet_typespec is None: @@ -271,7 +270,7 @@ def post_call(self, tag: Tag, packet: Packet) -> None: merged_info = {**tag, **packet.get_composite()} output_table = self.output_converter.from_python_packet_to_arrow_table(merged_info) # TODO: revisit this logic - output_id = self.arrow_hasher.hash_table(output_table) + output_id = self.arrow_hasher.hash_table(output_table, prefix_hasher_id=True) if not self.output_store.get_record(self.source_info, output_id): self.output_store.add_record( self.source_info, @@ -425,13 +424,18 @@ def registry(self, registry: SemanticTypeRegistry | None = None): self.update_cached_values() def update_cached_values(self) -> None: - self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) + self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod, prefix_hasher_id=True) + self.input_typespec, self.output_typespec = self.function_pod.get_function_typespecs() self.tag_keys, self.output_keys = self.keys(trigger_run=False) + + + if self.tag_keys is None or self.output_keys is None: raise ValueError( "Currently, cached function pod wrapper can only work with function pods that have keys defined." 
) - self.all_keys = tuple(self.tag_keys) + tuple(self.output_keys) + self.tag_keys = tuple(self.tag_keys) + self.output_keys = tuple(self.output_keys) self.tag_typespec, self.output_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.output_typespec is None: raise ValueError( @@ -475,7 +479,7 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: return super().forward(*streams, **kwargs) def get_packet_key(self, packet: Packet) -> str: - return self.arrow_hasher.hash_table(self.input_converter.from_python_packet_to_arrow_table(packet)) + return self.arrow_hasher.hash_table(self.input_converter.from_python_packet_to_arrow_table(packet), prefix_hasher_id=True) @property def source_info(self): @@ -502,7 +506,7 @@ def _add_pipeline_record_with_packet_key(self, tag: Tag, packet_key: str, packet table = self.tag_record_converter.from_python_packet_to_arrow_table(combined_info) - entry_hash = self.arrow_hasher.hash_table(table) + entry_hash = self.arrow_hasher.hash_table(table, prefix_hasher_id=True) # TODO: add error handling # check if record already exists: @@ -658,8 +662,8 @@ def get_all_entries_with_tags(self, keep_hidden_fields: bool = False) -> pl.Lazy ["__packet_key"] ) if not keep_hidden_fields: - pl_df = pl_df.select(self.all_keys) - return pl_df + pl_df = pl_df.select(self.tag_keys + self.output_keys) + return pl_df.lazy() @property def df(self) -> pl.DataFrame | None: diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 864f649..394a454 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -10,9 +10,8 @@ from orcapod.core import Invocation, Kernel, SyncStream from orcapod.core.pod import FunctionPod -from orcapod.pipeline.wrappers import KernelNode, FunctionPodNode, Node +from orcapod.pipeline.nodes import KernelNode, FunctionPodNode, Node -from orcapod.hashing import hash_to_hex from orcapod.core.tracker import GraphTracker from orcapod.stores import ArrowDataStore diff --git a/src/orcapod/stores/dict_data_stores.py b/src/orcapod/stores/dict_data_stores.py index edb44e5..c4eff60 100644 --- a/src/orcapod/stores/dict_data_stores.py +++ b/src/orcapod/stores/dict_data_stores.py @@ -4,7 +4,7 @@ from os import PathLike from pathlib import Path -from orcapod.hashing import hash_packet +from orcapod.hashing.legacy_core import hash_packet from orcapod.hashing.types import LegacyPacketHasher from orcapod.hashing.defaults import get_default_composite_file_hasher from orcapod.stores.types import DataStore diff --git a/src/orcapod/utils/object_spec.py b/src/orcapod/utils/object_spec.py new file mode 100644 index 0000000..f359a8c --- /dev/null +++ b/src/orcapod/utils/object_spec.py @@ -0,0 +1,19 @@ +import importlib + +def parse_objectspec(obj_spec: dict) -> Any: + if "_class" in obj_spec: + # if _class is specified, treat the dict as an object specification + module_name, class_name = obj_spec["_class"].rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + configs = parse_objectspec(obj_spec.get("config", {})) + return cls(**configs) + else: + # otherwise, parse through the dictionary recursively + parsed_object = obj_spec + for k, v in obj_spec.items(): + if isinstance(v, dict): + parsed_object[k] = parse_objectspec(v) + else: + parsed_object[k] = v + return parsed_object \ No newline at end of file From e689d0dad0a27e505755cb80d816e1536002d7ee Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 07:09:12 +0000 Subject: [PATCH 33/57] fix: failure to reset cache due to mro mixup --- src/orcapod/pipeline/nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index 2ac3763..ddc38b5 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -747,14 +747,14 @@ def __init__(self, kernel: Kernel, input_nodes: Collection["Node"], **kwargs): def reset_cache(self) -> None: ... -class KernelNode(Node, CachedKernelWrapper): +class KernelNode(CachedKernelWrapper, Node): """ A node that wraps a Kernel and provides a Node interface. This is useful for creating nodes in a pipeline that can be executed. """ -class FunctionPodNode(Node, CachedFunctionPodWrapper): +class FunctionPodNode(CachedFunctionPodWrapper, Node): """ A node that wraps a FunctionPod and provides a Node interface. This is useful for creating nodes in a pipeline that can be executed. From 62220649833f7c9b7543e928e08ae9ec0426ba12 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 1 Jul 2025 16:22:30 +0000 Subject: [PATCH 34/57] style: apply ruff format --- src/orcapod/core/base.py | 4 +- src/orcapod/core/operators.py | 21 +- src/orcapod/core/pod.py | 18 +- src/orcapod/core/streams.py | 1 - src/orcapod/core/tracker.py | 29 ++- src/orcapod/hashing/defaults.py | 10 +- src/orcapod/hashing/file_hashers.py | 14 +- src/orcapod/hashing/types.py | 15 +- src/orcapod/hashing/versioned_hashers.py | 15 +- src/orcapod/pipeline/__init__.py | 2 +- src/orcapod/pipeline/nodes.py | 140 ++++++++--- src/orcapod/pipeline/pipeline.py | 7 +- src/orcapod/stores/arrow_data_stores.py | 14 +- .../stores/delta_table_arrow_data_store.py | 228 +++++++++--------- src/orcapod/stores/optimized_memory_store.py | 120 ++++----- src/orcapod/stores/transfer_data_store.py | 4 +- src/orcapod/stores/types.py | 6 +- src/orcapod/types/core.py | 6 +- src/orcapod/types/packet_converter.py | 13 +- src/orcapod/types/packets.py | 141 +++++++---- src/orcapod/types/schemas.py | 115 +++++---- src/orcapod/types/semantic_type_registry.py | 48 ++-- src/orcapod/types/typespec_utils.py | 14 +- src/orcapod/utils/object_spec.py | 3 +- src/orcapod/utils/stream_utils.py | 9 +- .../test_basic_composite_hasher.py | 24 +- tests/test_hashing/test_hasher_factory.py | 24 +- .../test_legacy_composite_hasher.py | 15 +- tests/test_hashing/test_path_set_hasher.py | 7 +- .../test_string_cacher/test_redis_cacher.py | 4 +- 30 files changed, 643 insertions(+), 428 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index f0d5362..2144c34 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -386,7 +386,9 @@ def types(self, *, trigger_run=False) -> tuple[TypeSpec | None, TypeSpec | None] # otherwise, use the keys from the first packet in the stream # note that this may be computationally expensive tag, packet = next(iter(self)) - return tag_types or get_typespec_from_dict(tag), packet_types or get_typespec_from_dict(packet) + return tag_types or get_typespec_from_dict( + tag + ), packet_types or get_typespec_from_dict(packet) def claims_unique_tags(self, *, trigger_run=False) -> bool | None: """ diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index 598f2e3..bcf63e3 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -13,11 +13,10 @@ check_packet_compatibility, join_tags, semijoin_tags, - fill_missing + fill_missing, ) - class Repeat(Operator): """ A Mapper that 
repeats the packets in the stream a specified number of times. @@ -185,6 +184,7 @@ def claims_unique_tags( return True + def union_lists(left, right): if left is None or right is None: return None @@ -193,7 +193,7 @@ def union_lists(left, right): if item not in output: output.append(item) return output - + class Join(Operator): def identity_structure(self, *streams): @@ -423,7 +423,7 @@ def keys( stream = streams[0] tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) if tag_keys is None or packet_keys is None: - super_tag_keys, super_packet_keys = super().keys(trigger_run=trigger_run) + super_tag_keys, super_packet_keys = super().keys(trigger_run=trigger_run) tag_keys = tag_keys or super_tag_keys packet_keys = packet_keys or super_packet_keys @@ -583,10 +583,12 @@ def keys( return mapped_tag_keys, packet_keys + class SemiJoin(Operator): """ Perform semi-join on the left stream tags with the tags of the right stream """ + def identity_structure(self, *streams): # Restrict DOES depend on the order of the streams -- maintain as a tuple return (self.__class__.__name__,) + streams @@ -625,7 +627,9 @@ def forward(self, *streams: SyncStream) -> SyncStream: left_tag_typespec, left_packet_typespec = left_stream.types() right_tag_typespec, right_packet_typespec = right_stream.types() - common_tag_typespec = intersection_typespecs(left_tag_typespec, right_tag_typespec) + common_tag_typespec = intersection_typespecs( + left_tag_typespec, right_tag_typespec + ) common_tag_keys = None if common_tag_typespec is not None: common_tag_keys = list(common_tag_typespec.keys()) @@ -646,6 +650,7 @@ def generator() -> Iterator[tuple[Tag, Packet]]: def __repr__(self) -> str: return "SemiJoin()" + class Filter(Operator): """ A Mapper that filters the packets in the stream based on a predicate function. @@ -848,9 +853,9 @@ def generator() -> Iterator[tuple[Tag, Packet]]: if k not in new_tag: new_tag[k] = [t.get(k, None) for t, _ in packets] # combine all packets into a single packet - combined_packet: Packet = Packet({ - k: [p.get(k, None) for _, p in packets] for k in packet_keys - }) + combined_packet: Packet = Packet( + {k: [p.get(k, None) for _, p in packets] for k in packet_keys} + ) yield new_tag, combined_packet return SyncStreamFromGenerator(generator) diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index ae6778d..d64bafa 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -52,7 +52,6 @@ def set_active(self, active: bool) -> None: """ self._active = active - def process_stream(self, *streams: SyncStream) -> tuple[SyncStream, ...]: """ Prepare the incoming streams for execution in the pod. This default implementation @@ -72,7 +71,7 @@ def pre_forward_hook( self, *streams: SyncStream, **kwargs ) -> tuple[SyncStream, ...]: return self.process_stream(*streams) - + def generator_completion_hook(self, n_computed: int) -> None: """ Hook that is called when the generator is completed. 
This can be used to @@ -215,7 +214,9 @@ def __init__( ) ) - self.input_converter = PacketConverter(self.function_input_typespec, self.registry) + self.input_converter = PacketConverter( + self.function_input_typespec, self.registry + ) self.output_converter = PacketConverter( self.function_output_typespec, self.registry ) @@ -223,13 +224,16 @@ def __init__( def get_function_typespecs(self) -> tuple[TypeSpec, TypeSpec]: return self.function_input_typespec, self.function_output_typespec - def __repr__(self) -> str: return f"FunctionPod:{self.function!r}" def __str__(self) -> str: include_module = self.function.__module__ != "__main__" - func_sig = get_function_signature(self.function, name_override=self.function_name, include_module=include_module) + func_sig = get_function_signature( + self.function, + name_override=self.function_name, + include_module=include_module, + ) return f"FunctionPod:{func_sig}" def call(self, tag, packet) -> tuple[Tag, Packet | None]: @@ -258,7 +262,9 @@ def call(self, tag, packet) -> tuple[Tag, Packet | None]: f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" ) - output_packet: Packet = Packet({k: v for k, v in zip(self.output_keys, output_values)}) + output_packet: Packet = Packet( + {k: v for k, v in zip(self.output_keys, output_values)} + ) return tag, output_packet def identity_structure(self, *streams) -> Any: diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index 21060b1..243a1f4 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -104,4 +104,3 @@ def keys( return super().keys(trigger_run=trigger_run) # If the keys are already set, return them return self.tag_keys.copy(), self.packet_keys.copy() - \ No newline at end of file diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py index 337c027..8f07ae3 100644 --- a/src/orcapod/core/tracker.py +++ b/src/orcapod/core/tracker.py @@ -3,6 +3,7 @@ from collections.abc import Collection, Iterator from typing import Any + class StreamWrapper(SyncStream): """ A wrapper for a SyncStream that allows it to be used as a Source. @@ -14,12 +15,16 @@ def __init__(self, stream: SyncStream, **kwargs): super().__init__(**kwargs) self.stream = stream - def keys(self, *streams: SyncStream, **kwargs) -> tuple[Collection[str]|None, Collection[str]|None]: + def keys( + self, *streams: SyncStream, **kwargs + ) -> tuple[Collection[str] | None, Collection[str] | None]: return self.stream.keys(*streams, **kwargs) - def types(self, *streams: SyncStream, **kwargs) -> tuple[TypeSpec|None, TypeSpec|None]: + def types( + self, *streams: SyncStream, **kwargs + ) -> tuple[TypeSpec | None, TypeSpec | None]: return self.stream.types(*streams, **kwargs) - + def computed_label(self) -> str | None: return self.stream.label @@ -28,8 +33,7 @@ def __iter__(self) -> Iterator[tuple[Tag, Packet]]: Iterate over the stream, yielding tuples of (tags, packets). """ yield from self.stream - - + class StreamSource(Source): def __init__(self, stream: SyncStream, **kwargs): @@ -43,25 +47,28 @@ def forward(self, *streams: SyncStream) -> SyncStream: "It generates its own stream from the file system." ) return StreamWrapper(self.stream) - + def identity_structure(self, *streams) -> Any: if len(streams) != 0: raise ValueError( "StreamSource does not support forwarding streams. " "It generates its own stream from the file system." 
) - + return (self.__class__.__name__, self.stream) - def types(self, *streams: SyncStream, **kwargs) -> tuple[TypeSpec|None, TypeSpec|None]: + def types( + self, *streams: SyncStream, **kwargs + ) -> tuple[TypeSpec | None, TypeSpec | None]: return self.stream.types() - - def keys(self, *streams: SyncStream, **kwargs) -> tuple[Collection[str]|None, Collection[str]|None]: + + def keys( + self, *streams: SyncStream, **kwargs + ) -> tuple[Collection[str] | None, Collection[str] | None]: return self.stream.keys() def computed_label(self) -> str | None: return self.stream.label - class GraphTracker(Tracker): diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index f9dee37..a9aebcd 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -13,7 +13,10 @@ from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory from orcapod.hashing.arrow_hashers import SemanticArrowHasher from orcapod.hashing.semantic_type_hashers import PathHasher -from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher, get_versioned_object_hasher +from orcapod.hashing.versioned_hashers import ( + get_versioned_semantic_arrow_hasher, + get_versioned_object_hasher, +) def get_default_arrow_hasher( @@ -39,7 +42,6 @@ def get_default_arrow_hasher( def get_default_object_hasher() -> ObjectHasher: object_hasher = get_versioned_object_hasher() return object_hasher - def get_legacy_object_hasher() -> ObjectHasher: @@ -59,7 +61,9 @@ def get_default_composite_file_hasher(with_cache=True) -> LegacyCompositeFileHas return LegacyPathLikeHasherFactory.create_basic_legacy_composite() -def get_default_composite_file_hasher_with_cacher(cacher=None) -> LegacyCompositeFileHasher: +def get_default_composite_file_hasher_with_cacher( + cacher=None, +) -> LegacyCompositeFileHasher: if cacher is None: cacher = InMemoryCacher(max_size=None) return LegacyPathLikeHasherFactory.create_cached_legacy_composite(cacher) diff --git a/src/orcapod/hashing/file_hashers.py b/src/orcapod/hashing/file_hashers.py index 64f48f8..f0ca8d1 100644 --- a/src/orcapod/hashing/file_hashers.py +++ b/src/orcapod/hashing/file_hashers.py @@ -67,7 +67,6 @@ def hash_file(self, file_path: PathLike) -> str: ) - class LegacyCachedFileHasher: """File hasher with caching.""" @@ -90,7 +89,6 @@ def hash_file(self, file_path: PathLike) -> str: return value - class LegacyDefaultPathsetHasher: """Default pathset hasher that composes file hashing.""" @@ -107,11 +105,11 @@ def _hash_file_to_hex(self, file_path: PathLike) -> str: def hash_pathset(self, pathset: PathSet) -> str: """Hash a pathset using the injected file hasher.""" - return legacy_core.hash_pathset( - pathset, - char_count=self.char_count, - file_hasher=self.file_hasher.hash_file, # Inject the method - ) + return legacy_core.hash_pathset( + pathset, + char_count=self.char_count, + file_hasher=self.file_hasher.hash_file, # Inject the method + ) class LegacyDefaultPacketHasher: @@ -197,7 +195,7 @@ def create_cached_legacy_composite( return LegacyDefaultCompositeFileHasher( cached_file_hasher, char_count, packet_prefix=algorithm ) - + @staticmethod def create_legacy_file_hasher( string_cacher: StringCacher | None = None, diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index 24afdbc..fabf812 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -44,12 +44,14 @@ def hash(self, obj: Any) -> bytes: ... 
@abstractmethod - def get_hasher_id(self) -> str: + def get_hasher_id(self) -> str: """ Returns a unique identifier/name assigned to the hasher """ - def hash_to_hex(self, obj: Any, char_count: int | None = None, prefix_hasher_id:bool=False) -> str: + def hash_to_hex( + self, obj: Any, char_count: int | None = None, prefix_hasher_id: bool = False + ) -> str: hash_bytes = self.hash(obj) hex_str = hash_bytes.hex() @@ -95,6 +97,7 @@ def hash_file(self, file_path: PathLike) -> bytes: ... @runtime_checkable class ArrowHasher(Protocol): """Protocol for hashing arrow packets.""" + def get_hasher_id(self) -> str: ... def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: ... @@ -140,7 +143,7 @@ def set_cacher(self, cacher: StringCacher) -> None: pass -#---------------Legacy implementations and protocols to be deprecated--------------------- +# ---------------Legacy implementations and protocols to be deprecated--------------------- @runtime_checkable @@ -167,9 +170,9 @@ def hash_packet(self, packet: PacketLike) -> str: ... # Combined interface for convenience (optional) @runtime_checkable -class LegacyCompositeFileHasher(LegacyFileHasher, LegacyPathSetHasher, LegacyPacketHasher, Protocol): +class LegacyCompositeFileHasher( + LegacyFileHasher, LegacyPathSetHasher, LegacyPacketHasher, Protocol +): """Combined interface for all file-related hashing operations.""" pass - - diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py index 18c8680..e6095a0 100644 --- a/src/orcapod/hashing/versioned_hashers.py +++ b/src/orcapod/hashing/versioned_hashers.py @@ -35,15 +35,11 @@ "_class": "orcapod.hashing.object_hashers.BasicObjectHasher", "config": { "hasher_id": "object_v0.1", - "function_info_extractor" : { + "function_info_extractor": { "_class": "orcapod.hashing.function_info_extractors.FunctionSignatureExtractor", - "config": { - "include_module": True, - "include_defaults": True - } - - } - } + "config": {"include_module": True, "include_defaults": True}, + }, + }, } } @@ -91,7 +87,7 @@ def get_versioned_semantic_arrow_hasher( def get_versioned_object_hasher( version: str | None = None, -) -> ObjectHasher: +) -> ObjectHasher: """ Get an object hasher for the specified version. 
@@ -109,4 +105,3 @@ def get_versioned_object_hasher( hasher_spec = versioned_object_hashers[version] return parse_objectspec(hasher_spec) - diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 2bba49b..9a99f89 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -2,4 +2,4 @@ __all__ = [ "Pipeline", -] \ No newline at end of file +] diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index ddc38b5..b5bd54e 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -2,7 +2,11 @@ from orcapod.core import SyncStream, Source, Kernel from orcapod.stores import ArrowDataStore from orcapod.types import Tag, Packet, PacketLike, TypeSpec, default_registry -from orcapod.types.typespec_utils import get_typespec_from_dict, union_typespecs, extract_function_typespecs +from orcapod.types.typespec_utils import ( + get_typespec_from_dict, + union_typespecs, + extract_function_typespecs, +) from orcapod.types.semantic_type_registry import SemanticTypeRegistry from orcapod.types import packets, schemas from orcapod.hashing import ObjectHasher, ArrowHasher @@ -17,12 +21,18 @@ logger = logging.getLogger(__name__) + def get_tag_typespec(tag: Tag) -> dict[str, type]: return {k: str for k in tag} class PolarsSource(Source): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str], packet_keys: Collection[str]|None = None): + def __init__( + self, + df: pl.DataFrame, + tag_keys: Collection[str], + packet_keys: Collection[str] | None = None, + ): self.df = df self.tag_keys = tag_keys self.packet_keys = packet_keys @@ -37,7 +47,12 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: class PolarsStream(SyncStream): - def __init__(self, df: pl.DataFrame, tag_keys: Collection[str], packet_keys: Collection[str] | None = None): + def __init__( + self, + df: pl.DataFrame, + tag_keys: Collection[str], + packet_keys: Collection[str] | None = None, + ): self.df = df self.tag_keys = tuple(tag_keys) self.packet_keys = tuple(packet_keys) if packet_keys is not None else None @@ -48,9 +63,17 @@ def __iter__(self) -> Iterator[tuple[Tag, Packet]]: # df = df.select(self.tag_keys + self.packet_keys) for row in df.iter_rows(named=True): tag = {key: row[key] for key in self.tag_keys} - packet = {key: val for key, val in row.items() if key not in self.tag_keys and not key.startswith("_source_info_")} + packet = { + key: val + for key, val in row.items() + if key not in self.tag_keys and not key.startswith("_source_info_") + } # TODO: revisit and fix this rather hacky implementation - source_info = {key.removeprefix("_source_info_"):val for key, val in row.items() if key.startswith("_source_info_")} + source_info = { + key.removeprefix("_source_info_"): val + for key, val in row.items() + if key.startswith("_source_info_") + } yield tag, Packet(packet, source_info=source_info) @@ -142,8 +165,6 @@ def claims_unique_tags( *resolved_streams, trigger_run=trigger_run ) - - def post_call(self, tag: Tag, packet: Packet) -> None: ... def output_iterator_completion_hook(self) -> None: ... 
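# Illustrative sketch, not part of the committed patch: patch 33 ("mro mixup") swaps the
# node base order to KernelNode(CachedKernelWrapper, Node) so the wrapper's reset_cache()
# precedes Node's no-op stub in the method resolution order. The classes below are
# simplified stand-ins, not the orcapod classes, but they reproduce the failure mode.
class Node:
    def reset_cache(self) -> None: ...  # placeholder, intentionally does nothing

class CachedWrapper:
    def __init__(self) -> None:
        self.cache_cleared = False

    def reset_cache(self) -> None:
        self.cache_cleared = True  # the implementation that must win the MRO lookup

class BrokenNode(Node, CachedWrapper):  # Node's no-op shadows the real reset_cache
    pass

class FixedNode(CachedWrapper, Node):  # wrapper listed first, as in the fix
    pass

broken, fixed = BrokenNode(), FixedNode()
broken.reset_cache()
fixed.reset_cache()
assert not broken.cache_cleared and fixed.cache_cleared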
@@ -215,23 +236,31 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): self.update_cached_values() def update_cached_values(self): - self.source_info = self.store_path_prefix + (self.label, self.kernel_hasher.hash_to_hex(self.kernel, prefix_hasher_id=True)) + self.source_info = self.store_path_prefix + ( + self.label, + self.kernel_hasher.hash_to_hex(self.kernel, prefix_hasher_id=True), + ) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.packet_typespec is None: - raise ValueError("Currently, cached kernel wrapper can only work with kernels that have typespecs defined.") + raise ValueError( + "Currently, cached kernel wrapper can only work with kernels that have typespecs defined." + ) # TODO: clean up and make it unnecessary to convert packet typespec packet_schema = schemas.PythonSchema(self.packet_typespec) - joined_typespec = union_typespecs(self.tag_typespec, packet_schema.with_source_info) + joined_typespec = union_typespecs( + self.tag_typespec, packet_schema.with_source_info + ) if joined_typespec is None: raise ValueError( "Joined typespec should not be None. " "This may happen if the tag typespec and packet typespec are incompatible." ) # Add any additional fields to the output converter here - self.output_converter = packets.PacketConverter(joined_typespec, registry=self.registry, include_source_info=False) + self.output_converter = packets.PacketConverter( + joined_typespec, registry=self.registry, include_source_info=False + ) - def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if self._cache_computed: logger.info(f"Returning cached outputs for {self}") @@ -240,8 +269,10 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: raise ValueError( "CachedKernelWrapper has no tag keys defined, cannot return PolarsStream" ) - source_info_sig = ':'.join(self.source_info) - return PolarsStream(self.df, tag_keys=self.tag_keys, packet_keys=self.packet_keys) + source_info_sig = ":".join(self.source_info) + return PolarsStream( + self.df, tag_keys=self.tag_keys, packet_keys=self.packet_keys + ) else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.packet_keys) @@ -268,7 +299,9 @@ def post_call(self, tag: Tag, packet: Packet) -> None: # If an entry with same tag and packet already exists in the output store, # it will not be added again, thus avoiding duplicates. 
merged_info = {**tag, **packet.get_composite()} - output_table = self.output_converter.from_python_packet_to_arrow_table(merged_info) + output_table = self.output_converter.from_python_packet_to_arrow_table( + merged_info + ) # TODO: revisit this logic output_id = self.arrow_hasher.hash_table(output_table, prefix_hasher_id=True) if not self.output_store.get_record(self.source_info, output_id): @@ -285,7 +318,6 @@ def output_iterator_completion_hook(self) -> None: logger.info(f"Results cached for {self}") self._cache_computed = True - @property def lazy_df(self) -> pl.LazyFrame | None: return self.output_store.get_all_records_as_polars(self.source_info) @@ -424,11 +456,13 @@ def registry(self, registry: SemanticTypeRegistry | None = None): self.update_cached_values() def update_cached_values(self) -> None: - self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod, prefix_hasher_id=True) - self.input_typespec, self.output_typespec = self.function_pod.get_function_typespecs() + self.function_pod_hash = self.object_hasher.hash_to_hex( + self.function_pod, prefix_hasher_id=True + ) + self.input_typespec, self.output_typespec = ( + self.function_pod.get_function_typespecs() + ) self.tag_keys, self.output_keys = self.keys(trigger_run=False) - - if self.tag_keys is None or self.output_keys is None: raise ValueError( @@ -445,15 +479,26 @@ def update_cached_values(self) -> None: self.function_pod.get_function_typespecs() ) - self.input_converter = packets.PacketConverter(self.input_typespec, self.registry, include_source_info=False) - self.output_converter = packets.PacketConverter(self.output_typespec, self.registry, include_source_info=True) + self.input_converter = packets.PacketConverter( + self.input_typespec, self.registry, include_source_info=False + ) + self.output_converter = packets.PacketConverter( + self.output_typespec, self.registry, include_source_info=True + ) - input_packet_source_typespec = {f'_source_info_{k}': str for k in self.input_typespec} + input_packet_source_typespec = { + f"_source_info_{k}": str for k in self.input_typespec + } # prepare typespec for tag record: __packet_key, tag, input packet source_info, - tag_record_typespec = {"__packet_key": str, **self.tag_typespec, **input_packet_source_typespec} - self.tag_record_converter = packets.PacketConverter(tag_record_typespec, self.registry, include_source_info=False) - + tag_record_typespec = { + "__packet_key": str, + **self.tag_typespec, + **input_packet_source_typespec, + } + self.tag_record_converter = packets.PacketConverter( + tag_record_typespec, self.registry, include_source_info=False + ) def reset_cache(self): self._cache_computed = False @@ -472,14 +517,19 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: if lazy_df is not None: if self.tag_keys is None: raise ValueError("Tag keys are not set, cannot return PolarsStream") - return PolarsStream(lazy_df.collect(), self.tag_keys, packet_keys=self.output_keys) + return PolarsStream( + lazy_df.collect(), self.tag_keys, packet_keys=self.output_keys + ) else: return EmptyStream(tag_keys=self.tag_keys, packet_keys=self.output_keys) logger.info(f"Computing and caching outputs for {self}") return super().forward(*streams, **kwargs) def get_packet_key(self, packet: Packet) -> str: - return self.arrow_hasher.hash_table(self.input_converter.from_python_packet_to_arrow_table(packet), prefix_hasher_id=True) + return self.arrow_hasher.hash_table( + self.input_converter.from_python_packet_to_arrow_table(packet), + prefix_hasher_id=True, + ) 
@property def source_info(self): @@ -493,18 +543,24 @@ def add_pipeline_record(self, tag: Tag, packet: Packet) -> Tag: Record the tag for the packet in the record store. This is used to keep track of the tags associated with memoized packets. """ - return self._add_pipeline_record_with_packet_key(tag, self.get_packet_key(packet), packet.source_info) + return self._add_pipeline_record_with_packet_key( + tag, self.get_packet_key(packet), packet.source_info + ) - def _add_pipeline_record_with_packet_key(self, tag: Tag, packet_key: str, packet_source_info: dict[str, str | None]) -> Tag: + def _add_pipeline_record_with_packet_key( + self, tag: Tag, packet_key: str, packet_source_info: dict[str, str | None] + ) -> Tag: if self.tag_store is None: raise ValueError("Recording of tag requires tag_store but none provided") combined_info = dict(tag) # ensure we don't modify the original tag combined_info["__packet_key"] = packet_key for k, v in packet_source_info.items(): - combined_info[f'_source_info_{k}'] = v + combined_info[f"_source_info_{k}"] = v - table = self.tag_record_converter.from_python_packet_to_arrow_table(combined_info) + table = self.tag_record_converter.from_python_packet_to_arrow_table( + combined_info + ) entry_hash = self.arrow_hasher.hash_table(table, prefix_hasher_id=True) @@ -553,7 +609,9 @@ def memoize( Returns the memoized packet. """ logger.debug("Memoizing packet") - return self._memoize_with_packet_key(self.get_packet_key(packet), output_packet.get_composite()) + return self._memoize_with_packet_key( + self.get_packet_key(packet), output_packet.get_composite() + ) def _memoize_with_packet_key( self, packet_key: str, output_packet: PacketLike @@ -581,7 +639,6 @@ def _memoize_with_packet_key( # attach provenance information return Packet(packet) - def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: packet_key = "" if ( @@ -609,8 +666,11 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: if output_packet is not None and not self.skip_memoization: # output packet may be modified by the memoization process # e.g. if the output is a file, the path may be changed - # add source info to the output packet - source_info = {k: '-'.join(self.source_info) + "-" + packet_key for k in output_packet.source_info} + # add source info to the output packet + source_info = { + k: "-".join(self.source_info) + "-" + packet_key + for k in output_packet.source_info + } # TODO: fix and make this not access protected field directly output_packet.source_info = source_info output_packet = self._memoize_with_packet_key(packet_key, output_packet) # type: ignore @@ -624,7 +684,9 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: # result was successfully computed -- save the tag if not self.skip_tag_record and self.tag_store is not None: - self._add_pipeline_record_with_packet_key(tag, packet_key, packet.source_info) + self._add_pipeline_record_with_packet_key( + tag, packet_key, packet.source_info + ) return tag, output_packet @@ -639,7 +701,9 @@ def get_all_tags(self, with_packet_id: bool = False) -> pl.LazyFrame | None: return data.drop("__packet_key") if data is not None else None return data - def get_all_entries_with_tags(self, keep_hidden_fields: bool = False) -> pl.LazyFrame | None: + def get_all_entries_with_tags( + self, keep_hidden_fields: bool = False + ) -> pl.LazyFrame | None: """ Retrieve all entries from the tag store with their associated tags. Returns a DataFrame with columns for tag and packet key. 
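# Illustrative sketch, not part of the committed patch: the cached function-pod wrapper
# keys its memoization store by a content hash of the input packet (get_packet_key wraps
# arrow_hasher.hash_table with prefix_hasher_id=True) and records each tag together with
# that "__packet_key" so results can later be joined back to their tags. The dict-backed
# stores and helper names below are hypothetical simplifications of that flow.
import hashlib
import json

def toy_packet_key(packet: dict) -> str:
    # stand-in for hashing the packet's Arrow representation
    digest = hashlib.sha256(json.dumps(packet, sort_keys=True).encode()).hexdigest()
    return f"toy_v0:{digest}"

memo_store: dict[str, dict] = {}  # packet_key -> memoized output packet
tag_records: list[dict] = []      # tag fields plus "__packet_key"

def toy_call(tag: dict, packet: dict, fn) -> tuple[dict, dict]:
    key = toy_packet_key(packet)
    output = memo_store.get(key)
    if output is None:            # cache miss: compute, then memoize under the key
        output = fn(**packet)
        memo_store[key] = output
    tag_records.append({**tag, "__packet_key": key})
    return tag, output

_, out = toy_call({"sample": "s1"}, {"x": 2}, lambda x: {"y": x * x})
assert out == {"y": 4}
assert tag_records[0]["__packet_key"] == toy_packet_key({"x": 2})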
diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 394a454..7e04d96 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -95,7 +95,12 @@ def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node tag_store=self.pipeline_store, store_path_prefix=self.name, ) - return KernelNode(kernel, input_nodes, output_store=self.pipeline_store, store_path_prefix=self.name) + return KernelNode( + kernel, + input_nodes, + output_store=self.pipeline_store, + store_path_prefix=self.name, + ) def compile(self): import networkx as nx diff --git a/src/orcapod/stores/arrow_data_stores.py b/src/orcapod/stores/arrow_data_stores.py index 2608cbc..2897ead 100644 --- a/src/orcapod/stores/arrow_data_stores.py +++ b/src/orcapod/stores/arrow_data_stores.py @@ -24,7 +24,11 @@ def __init__(self): logger.info("Initialized MockArrowDataStore") def add_record( - self, source_pathh: tuple[str, ...], source_id: str, entry_id: str, arrow_data: pa.Table + self, + source_pathh: tuple[str, ...], + source_id: str, + entry_id: str, + arrow_data: pa.Table, ) -> pa.Table: """Add a record to the mock store.""" return arrow_data @@ -35,7 +39,9 @@ def get_record( """Get a specific record.""" return None - def get_all_records(self, source_path: tuple[str, ...], source_id: str) -> pa.Table | None: + def get_all_records( + self, source_path: tuple[str, ...], source_id: str + ) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" return None @@ -47,7 +53,7 @@ def get_all_records_as_polars( def get_records_by_ids( self, - source_path: tuple[str,...], + source_path: tuple[str, ...], source_id: str, entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, @@ -77,7 +83,7 @@ def get_records_by_ids( def get_records_by_ids_as_polars( self, - source_path: tuple[str,...], + source_path: tuple[str, ...], source_id: str, entry_ids: list[str] | pl.Series | pa.Array, add_entry_id_column: bool | str = False, diff --git a/src/orcapod/stores/delta_table_arrow_data_store.py b/src/orcapod/stores/delta_table_arrow_data_store.py index d4fcaf3..c05dea9 100644 --- a/src/orcapod/stores/delta_table_arrow_data_store.py +++ b/src/orcapod/stores/delta_table_arrow_data_store.py @@ -13,7 +13,7 @@ class DeltaTableArrowDataStore: """ Delta Table-based Arrow data store with flexible hierarchical path support. - + Uses tuple-based source paths for robust parameter handling: - ("source_name", "source_id") -> source_name/source_id/ - ("org", "project", "dataset") -> org/project/dataset/ @@ -21,11 +21,11 @@ class DeltaTableArrowDataStore: """ def __init__( - self, - base_path: str | Path, + self, + base_path: str | Path, duplicate_entry_behavior: str = "error", create_base_path: bool = True, - max_hierarchy_depth: int = 10 + max_hierarchy_depth: int = 10, ): """ Initialize the DeltaTableArrowDataStore. 
@@ -41,19 +41,21 @@ def __init__( # Validate duplicate behavior if duplicate_entry_behavior not in ["error", "overwrite"]: raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - + self.duplicate_entry_behavior = duplicate_entry_behavior self.base_path = Path(base_path) self.max_hierarchy_depth = max_hierarchy_depth - + if create_base_path: self.base_path.mkdir(parents=True, exist_ok=True) elif not self.base_path.exists(): - raise ValueError(f"Base path {self.base_path} does not exist and create_base_path=False") - + raise ValueError( + f"Base path {self.base_path} does not exist and create_base_path=False" + ) + # Cache for Delta tables to avoid repeated initialization self._delta_table_cache: dict[str, DeltaTable] = {} - + logger.info( f"Initialized DeltaTableArrowDataStore at {self.base_path} " f"with duplicate_entry_behavior='{duplicate_entry_behavior}'" @@ -62,28 +64,34 @@ def __init__( def _validate_source_path(self, source_path: tuple[str, ...]) -> None: """ Validate source path components. - + Args: source_path: Tuple of path components - + Raises: ValueError: If path is invalid """ if not source_path: raise ValueError("Source path cannot be empty") - + if len(source_path) > self.max_hierarchy_depth: - raise ValueError(f"Source path depth {len(source_path)} exceeds maximum {self.max_hierarchy_depth}") - + raise ValueError( + f"Source path depth {len(source_path)} exceeds maximum {self.max_hierarchy_depth}" + ) + # Validate path components for i, component in enumerate(source_path): if not component or not isinstance(component, str): - raise ValueError(f"Source path component {i} is invalid: {repr(component)}") - + raise ValueError( + f"Source path component {i} is invalid: {repr(component)}" + ) + # Check for filesystem-unsafe characters - unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\0'] + unsafe_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\0"] if any(char in component for char in unsafe_chars): - raise ValueError(f"Source path component contains invalid characters: {repr(component)}") + raise ValueError( + f"Source path component contains invalid characters: {repr(component)}" + ) def _get_source_key(self, source_path: tuple[str, ...]) -> str: """Generate cache key for source storage.""" @@ -115,13 +123,11 @@ def _remove_entry_id_column(self, arrow_data: pa.Table) -> pa.Table: return arrow_data def _handle_entry_id_column( - self, - arrow_data: pa.Table, - add_entry_id_column: bool | str = False + self, arrow_data: pa.Table, add_entry_id_column: bool | str = False ) -> pa.Table: """ Handle entry_id column based on add_entry_id_column parameter. - + Args: arrow_data: Arrow table with __entry_id column add_entry_id_column: Control entry ID column inclusion: @@ -167,16 +173,16 @@ def add_record( ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) source_key = self._get_source_key(source_path) - + # Ensure directory exists table_path.mkdir(parents=True, exist_ok=True) - + # Add entry_id column to the data data_with_entry_id = self._ensure_entry_id_column(arrow_data, entry_id) - + # Check for existing entry if needed if not ignore_duplicate and self.duplicate_entry_behavior == "error": existing_record = self.get_record(source_path, entry_id) @@ -185,41 +191,36 @@ def add_record( f"Entry '{entry_id}' already exists in {'/'.join(source_path)}. " f"Use duplicate_entry_behavior='overwrite' to allow updates." 
) - + try: # Try to load existing table delta_table = DeltaTable(str(table_path)) - + if self.duplicate_entry_behavior == "overwrite": # Delete existing record if it exists, then append new one try: # First, delete existing record with this entry_id delta_table.delete(f"__entry_id = '{entry_id}'") - logger.debug(f"Deleted existing record {entry_id} from {source_key}") + logger.debug( + f"Deleted existing record {entry_id} from {source_key}" + ) except Exception as e: # If delete fails (e.g., record doesn't exist), that's fine logger.debug(f"No existing record to delete for {entry_id}: {e}") - + # Append new record write_deltalake( - str(table_path), - data_with_entry_id, - mode="append", - schema_mode="merge" + str(table_path), data_with_entry_id, mode="append", schema_mode="merge" ) - + except TableNotFoundError: # Table doesn't exist, create it - write_deltalake( - str(table_path), - data_with_entry_id, - mode="overwrite" - ) + write_deltalake(str(table_path), data_with_entry_id, mode="overwrite") logger.debug(f"Created new Delta table for {source_key}") - + # Update cache self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - + logger.debug(f"Added record {entry_id} to {source_key}") return arrow_data @@ -228,36 +229,36 @@ def get_record( ) -> pa.Table | None: """ Get a specific record by entry_id. - + Args: source_path: Tuple of path components entry_id: Unique identifier for the record - + Returns: Arrow table for the record, or None if not found """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) - + try: delta_table = DeltaTable(str(table_path)) - + # Query for the specific entry_id - result = delta_table.to_pyarrow_table( - filter=f"__entry_id = '{entry_id}'" - ) - + result = delta_table.to_pyarrow_table(filter=f"__entry_id = '{entry_id}'") + if len(result) == 0: return None - + # Remove the __entry_id column before returning return self._remove_entry_id_column(result) - + except TableNotFoundError: return None except Exception as e: - logger.error(f"Error getting record {entry_id} from {'/'.join(source_path)}: {e}") + logger.error( + f"Error getting record {entry_id} from {'/'.join(source_path)}: {e}" + ) return None def get_all_records( @@ -265,31 +266,31 @@ def get_all_records( ) -> pa.Table | None: """ Retrieve all records for a given source path as a single table. - + Args: source_path: Tuple of path components add_entry_id_column: Control entry ID column inclusion: - False: Don't include entry ID column (default) - True: Include entry ID column as "__entry_id" - str: Include entry ID column with custom name - + Returns: Arrow table containing all records, or None if no records found """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) - + try: delta_table = DeltaTable(str(table_path)) result = delta_table.to_pyarrow_table() - + if len(result) == 0: return None - + # Handle entry_id column based on parameter return self._handle_entry_id_column(result, add_entry_id_column) - + except TableNotFoundError: return None except Exception as e: @@ -301,10 +302,10 @@ def get_all_records_as_polars( ) -> pl.LazyFrame | None: """ Retrieve all records for a given source path as a single Polars LazyFrame. 
- + Args: source_path: Tuple of path components - + Returns: Polars LazyFrame containing all records, or None if no records found """ @@ -333,7 +334,7 @@ def get_records_by_ids( Arrow table containing all found records, or None if no records found """ self._validate_source_path(source_path) - + # Convert input to list of strings for consistency if isinstance(entry_ids, list): if not entry_ids: @@ -353,39 +354,41 @@ def get_records_by_ids( ) table_path = self._get_table_path(source_path) - + try: delta_table = DeltaTable(str(table_path)) - + # Create filter for the entry IDs - escape single quotes in IDs escaped_ids = [id_.replace("'", "''") for id_ in entry_ids_list] id_filter = " OR ".join([f"__entry_id = '{id_}'" for id_ in escaped_ids]) - + result = delta_table.to_pyarrow_table(filter=id_filter) - + if len(result) == 0: return None - + if preserve_input_order: # Need to reorder results and add nulls for missing entries import pandas as pd - + df = result.to_pandas() - df = df.set_index('__entry_id') - + df = df.set_index("__entry_id") + # Create a DataFrame with the desired order, filling missing with NaN ordered_df = df.reindex(entry_ids_list) - + # Convert back to Arrow result = pa.Table.from_pandas(ordered_df.reset_index()) - + # Handle entry_id column based on parameter return self._handle_entry_id_column(result, add_entry_id_column) - + except TableNotFoundError: return None except Exception as e: - logger.error(f"Error getting records by IDs from {'/'.join(source_path)}: {e}") + logger.error( + f"Error getting records by IDs from {'/'.join(source_path)}: {e}" + ) return None def get_records_by_ids_as_polars( @@ -397,13 +400,13 @@ def get_records_by_ids_as_polars( ) -> pl.LazyFrame | None: """ Retrieve records by entry IDs as a single Polars LazyFrame. - + Args: source_path: Tuple of path components entry_ids: Entry IDs to retrieve add_entry_id_column: Control entry ID column inclusion preserve_input_order: If True, return results in input order with nulls for missing - + Returns: Polars LazyFrame containing all found records, or None if no records found """ @@ -421,20 +424,20 @@ def get_records_by_ids_as_polars( def list_sources(self) -> list[tuple[str, ...]]: """ List all available source paths. - + Returns: List of source path tuples """ sources = [] - + def _scan_directory(current_path: Path, path_components: tuple[str, ...]): """Recursively scan for Delta tables.""" for item in current_path.iterdir(): if not item.is_dir(): continue - + new_path_components = path_components + (item.name,) - + # Check if this directory contains a Delta table try: DeltaTable(str(item)) @@ -443,40 +446,41 @@ def _scan_directory(current_path: Path, path_components: tuple[str, ...]): # Not a Delta table, continue scanning subdirectories if len(new_path_components) < self.max_hierarchy_depth: _scan_directory(item, new_path_components) - + _scan_directory(self.base_path, ()) return sources def delete_source(self, source_path: tuple[str, ...]) -> bool: """ Delete an entire source (all records for a source path). 
- + Args: source_path: Tuple of path components - + Returns: True if source was deleted, False if it didn't exist """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) source_key = self._get_source_key(source_path) - + if not table_path.exists(): return False - + try: # Remove from cache if source_key in self._delta_table_cache: del self._delta_table_cache[source_key] - + # Remove directory import shutil + shutil.rmtree(table_path) - + logger.info(f"Deleted source {source_key}") return True - + except Exception as e: logger.error(f"Error deleting source {source_key}: {e}") return False @@ -484,64 +488,68 @@ def delete_source(self, source_path: tuple[str, ...]) -> bool: def delete_record(self, source_path: tuple[str, ...], entry_id: str) -> bool: """ Delete a specific record. - + Args: source_path: Tuple of path components entry_id: ID of the record to delete - + Returns: True if record was deleted, False if it didn't exist """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) - + try: delta_table = DeltaTable(str(table_path)) - + # Check if record exists escaped_entry_id = entry_id.replace("'", "''") - existing = delta_table.to_pyarrow_table(filter=f"__entry_id = '{escaped_entry_id}'") + existing = delta_table.to_pyarrow_table( + filter=f"__entry_id = '{escaped_entry_id}'" + ) if len(existing) == 0: return False - + # Delete the record delta_table.delete(f"__entry_id = '{escaped_entry_id}'") - + # Update cache source_key = self._get_source_key(source_path) self._delta_table_cache[source_key] = delta_table - + logger.debug(f"Deleted record {entry_id} from {'/'.join(source_path)}") return True - + except TableNotFoundError: return False except Exception as e: - logger.error(f"Error deleting record {entry_id} from {'/'.join(source_path)}: {e}") + logger.error( + f"Error deleting record {entry_id} from {'/'.join(source_path)}: {e}" + ) return False def get_table_info(self, source_path: tuple[str, ...]) -> dict[str, Any] | None: """ Get metadata information about a Delta table. - + Args: source_path: Tuple of path components - + Returns: Dictionary with table metadata, or None if table doesn't exist """ self._validate_source_path(source_path) - + table_path = self._get_table_path(source_path) - + try: delta_table = DeltaTable(str(table_path)) - + # Get basic info schema = delta_table.schema() history = delta_table.history() - + return { "path": str(table_path), "source_path": source_path, @@ -551,9 +559,9 @@ def get_table_info(self, source_path: tuple[str, ...]) -> dict[str, Any] | None: "history_length": len(history), "latest_commit": history[0] if history else None, } - + except TableNotFoundError: return None except Exception as e: logger.error(f"Error getting table info for {'/'.join(source_path)}: {e}") - return None \ No newline at end of file + return None diff --git a/src/orcapod/stores/optimized_memory_store.py b/src/orcapod/stores/optimized_memory_store.py index ff962e9..1859113 100644 --- a/src/orcapod/stores/optimized_memory_store.py +++ b/src/orcapod/stores/optimized_memory_store.py @@ -11,7 +11,7 @@ class ArrowBatchedPolarsDataStore: """ Arrow-batched Polars data store that minimizes Arrow<->Polars conversions. - + Key optimizations: 1. Keep data in Arrow format during batching 2. 
Only convert to Polars when consolidating or querying @@ -32,22 +32,22 @@ def __init__(self, duplicate_entry_behavior: str = "error", batch_size: int = 10 """ if duplicate_entry_behavior not in ["error", "overwrite"]: raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - + self.duplicate_entry_behavior = duplicate_entry_behavior self.batch_size = batch_size # Arrow batch buffer: {source_key: [(entry_id, arrow_table), ...]} self._arrow_batches: Dict[str, List[Tuple[str, pa.Table]]] = defaultdict(list) - + # Consolidated Polars store: {source_key: polars_dataframe} self._polars_store: Dict[str, pl.DataFrame] = {} - + # Entry ID index for fast lookups: {source_key: set[entry_ids]} self._entry_index: Dict[str, set] = defaultdict(set) - + # Schema cache self._schema_cache: Dict[str, pa.Schema] = {} - + logger.info( f"Initialized ArrowBatchedPolarsDataStore with " f"duplicate_entry_behavior='{duplicate_entry_behavior}', batch_size={batch_size}" @@ -61,7 +61,7 @@ def _add_entry_id_to_arrow_table(self, table: pa.Table, entry_id: str) -> pa.Tab """Add entry_id column to Arrow table efficiently.""" # Create entry_id array with the same length as the table entry_id_array = pa.array([entry_id] * len(table), type=pa.string()) - + # Add column at the beginning for consistent ordering return table.add_column(0, "__entry_id", entry_id_array) @@ -69,36 +69,40 @@ def _consolidate_arrow_batch(self, source_key: str) -> None: """Consolidate Arrow batch into Polars DataFrame.""" if source_key not in self._arrow_batches or not self._arrow_batches[source_key]: return - - logger.debug(f"Consolidating {len(self._arrow_batches[source_key])} Arrow tables for {source_key}") - + + logger.debug( + f"Consolidating {len(self._arrow_batches[source_key])} Arrow tables for {source_key}" + ) + # Prepare all Arrow tables with entry_id columns arrow_tables_with_id = [] - + for entry_id, arrow_table in self._arrow_batches[source_key]: table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) arrow_tables_with_id.append(table_with_id) - + # Concatenate all Arrow tables at once (very fast) if len(arrow_tables_with_id) == 1: consolidated_arrow = arrow_tables_with_id[0] else: consolidated_arrow = pa.concat_tables(arrow_tables_with_id) - + # Single conversion to Polars new_polars_df = cast(pl.DataFrame, pl.from_arrow(consolidated_arrow)) - + # Combine with existing Polars DataFrame if it exists if source_key in self._polars_store: existing_df = self._polars_store[source_key] self._polars_store[source_key] = pl.concat([existing_df, new_polars_df]) else: self._polars_store[source_key] = new_polars_df - + # Clear the Arrow batch self._arrow_batches[source_key].clear() - - logger.debug(f"Consolidated to Polars DataFrame with {len(self._polars_store[source_key])} total rows") + + logger.debug( + f"Consolidated to Polars DataFrame with {len(self._polars_store[source_key])} total rows" + ) def _force_consolidation(self, source_key: str) -> None: """Force consolidation of Arrow batches.""" @@ -119,7 +123,7 @@ def add_record( ) -> pa.Table: """ Add a record to the store using Arrow batching. - + This is the fastest path - no conversions, just Arrow table storage. 
""" source_key = self._get_source_key(source_name, source_id) @@ -135,15 +139,16 @@ def add_record( # Handle overwrite: remove from both Arrow batch and Polars store # Remove from Arrow batch self._arrow_batches[source_key] = [ - (eid, table) for eid, table in self._arrow_batches[source_key] + (eid, table) + for eid, table in self._arrow_batches[source_key] if eid != entry_id ] - + # Remove from Polars store if it exists if source_key in self._polars_store: - self._polars_store[source_key] = self._polars_store[source_key].filter( - pl.col("__entry_id") != entry_id - ) + self._polars_store[source_key] = self._polars_store[ + source_key + ].filter(pl.col("__entry_id") != entry_id) # Schema validation (cached) if source_key in self._schema_cache: @@ -159,7 +164,7 @@ def add_record( # Add to Arrow batch (no conversion yet!) self._arrow_batches[source_key].append((entry_id, arrow_data)) self._entry_index[source_key].add(entry_id) - + # Consolidate if batch is full if len(self._arrow_batches[source_key]) >= self.batch_size: self._consolidate_arrow_batch(source_key) @@ -172,16 +177,16 @@ def get_record( ) -> pa.Table | None: """Get a specific record with optimized lookup.""" source_key = self._get_source_key(source_name, source_id) - + # Quick existence check if entry_id not in self._entry_index[source_key]: return None - + # Check Arrow batch first (most recent data) for batch_entry_id, arrow_table in self._arrow_batches[source_key]: if batch_entry_id == entry_id: return arrow_table - + # Check consolidated Polars store df = self._get_consolidated_dataframe(source_key) if df is None: @@ -189,7 +194,7 @@ def get_record( # Filter and convert back to Arrow filtered_df = df.filter(pl.col("__entry_id") == entry_id).drop("__entry_id") - + if filtered_df.height == 0: return None @@ -200,7 +205,7 @@ def get_all_records( ) -> pa.Table | None: """Retrieve all records as a single Arrow table.""" source_key = self._get_source_key(source_name, source_id) - + # Force consolidation to include all data df = self._get_consolidated_dataframe(source_key) if df is None or df.height == 0: @@ -223,7 +228,7 @@ def get_all_records_as_polars( ) -> pl.LazyFrame | None: """Retrieve all records as a Polars LazyFrame.""" source_key = self._get_source_key(source_name, source_id) - + df = self._get_consolidated_dataframe(source_key) if df is None or df.height == 0: return None @@ -256,20 +261,21 @@ def get_records_by_ids( raise TypeError(f"entry_ids must be list[str], pl.Series, or pa.Array") source_key = self._get_source_key(source_name, source_id) - + # Quick filter using index existing_entries = [ - entry_id for entry_id in entry_ids_list + entry_id + for entry_id in entry_ids_list if entry_id in self._entry_index[source_key] ] - + if not existing_entries and not preserve_input_order: return None # Collect from Arrow batch first batch_tables = [] found_in_batch = set() - + for entry_id, arrow_table in self._arrow_batches[source_key]: if entry_id in entry_ids_list: table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) @@ -278,7 +284,7 @@ def get_records_by_ids( # Get remaining from consolidated store remaining_ids = [eid for eid in existing_entries if eid not in found_in_batch] - + consolidated_tables = [] if remaining_ids: df = self._get_consolidated_dataframe(source_key) @@ -288,13 +294,13 @@ def get_records_by_ids( result_df = ordered_df.join(df, on="__entry_id", how="left") else: result_df = df.filter(pl.col("__entry_id").is_in(remaining_ids)) - + if result_df.height > 0: 
consolidated_tables.append(result_df.to_arrow()) # Combine all results all_tables = batch_tables + consolidated_tables - + if not all_tables: return None @@ -309,7 +315,9 @@ def get_records_by_ids( # Remove __entry_id column column_names = result_table.column_names if "__entry_id" in column_names: - indices = [i for i, name in enumerate(column_names) if name != "__entry_id"] + indices = [ + i for i, name in enumerate(column_names) if name != "__entry_id" + ] result_table = result_table.select(indices) elif isinstance(add_entry_id_column, str): # Rename __entry_id column @@ -337,7 +345,7 @@ def get_records_by_ids_as_polars( if arrow_result is None: return None - + pl_result = cast(pl.DataFrame, pl.from_arrow(arrow_result)) return pl_result.lazy() @@ -370,7 +378,7 @@ def force_consolidation(self) -> None: def clear_source(self, source_name: str, source_id: str) -> None: """Clear all data for a source.""" source_key = self._get_source_key(source_name, source_id) - + if source_key in self._arrow_batches: del self._arrow_batches[source_key] if source_key in self._polars_store: @@ -379,7 +387,7 @@ def clear_source(self, source_name: str, source_id: str) -> None: del self._entry_index[source_key] if source_key in self._schema_cache: del self._schema_cache[source_key] - + logger.debug(f"Cleared source {source_key}") def clear_all(self) -> None: @@ -394,25 +402,29 @@ def get_stats(self) -> dict[str, Any]: """Get comprehensive statistics.""" total_records = sum(len(entries) for entries in self._entry_index.values()) total_batched = sum(len(batch) for batch in self._arrow_batches.values()) - total_consolidated = sum( - len(df) for df in self._polars_store.values() - ) if self._polars_store else 0 - + total_consolidated = ( + sum(len(df) for df in self._polars_store.values()) + if self._polars_store + else 0 + ) + source_stats = [] for source_key in self._entry_index.keys(): record_count = len(self._entry_index[source_key]) batched_count = len(self._arrow_batches.get(source_key, [])) consolidated_count = 0 - + if source_key in self._polars_store: consolidated_count = len(self._polars_store[source_key]) - - source_stats.append({ - "source_key": source_key, - "total_records": record_count, - "batched_records": batched_count, - "consolidated_records": consolidated_count, - }) + + source_stats.append( + { + "source_key": source_key, + "total_records": record_count, + "batched_records": batched_count, + "consolidated_records": consolidated_count, + } + ) return { "total_records": total_records, @@ -430,4 +442,4 @@ def optimize_for_reads(self) -> None: self.force_consolidation() # Clear Arrow batches to save memory self._arrow_batches.clear() - logger.info("Optimization complete") \ No newline at end of file + logger.info("Optimization complete") diff --git a/src/orcapod/stores/transfer_data_store.py b/src/orcapod/stores/transfer_data_store.py index 0c8e215..9e393e0 100644 --- a/src/orcapod/stores/transfer_data_store.py +++ b/src/orcapod/stores/transfer_data_store.py @@ -14,7 +14,9 @@ def __init__(self, source_store: DataStore, target_store: DataStore) -> None: self.source_store = source_store self.target_store = target_store - def transfer(self, function_name: str, content_hash: str, packet: PacketLike) -> PacketLike: + def transfer( + self, function_name: str, content_hash: str, packet: PacketLike + ) -> PacketLike: """ Transfer a memoized packet from the source store to the target store. 
""" diff --git a/src/orcapod/stores/types.py b/src/orcapod/stores/types.py index c588856..da7e492 100644 --- a/src/orcapod/stores/types.py +++ b/src/orcapod/stores/types.py @@ -48,15 +48,15 @@ def add_record( ) -> pa.Table: ... def get_record( - self, source_path: tuple[str,...], entry_id: str + self, source_path: tuple[str, ...], entry_id: str ) -> pa.Table | None: ... - def get_all_records(self, source_path: tuple[str,...]) -> pa.Table | None: + def get_all_records(self, source_path: tuple[str, ...]) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" ... def get_all_records_as_polars( - self, source_path: tuple[str,...] + self, source_path: tuple[str, ...] ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a single Polars DataFrame.""" ... diff --git a/src/orcapod/types/core.py b/src/orcapod/types/core.py index dd02141..12448f8 100644 --- a/src/orcapod/types/core.py +++ b/src/orcapod/types/core.py @@ -5,7 +5,6 @@ from collections.abc import Collection, Mapping - DataType: TypeAlias = type TypeSpec: TypeAlias = Mapping[ @@ -34,8 +33,9 @@ # Extended data values that can be stored in packets # Either the original PathSet or one of our supported simple data types -DataValue: TypeAlias = PathSet | SupportedNativePythonData | None | Collection["DataValue"] - +DataValue: TypeAlias = ( + PathSet | SupportedNativePythonData | None | Collection["DataValue"] +) class PodFunction(Protocol): diff --git a/src/orcapod/types/packet_converter.py b/src/orcapod/types/packet_converter.py index 0a8389d..e486222 100644 --- a/src/orcapod/types/packet_converter.py +++ b/src/orcapod/types/packet_converter.py @@ -1,6 +1,11 @@ from orcapod.types.core import TypeSpec, TypeHandler from orcapod.types.packets import Packet, PacketLike -from orcapod.types.semantic_type_registry import SemanticTypeRegistry, TypeInfo, get_metadata_from_schema, arrow_to_dicts +from orcapod.types.semantic_type_registry import ( + SemanticTypeRegistry, + TypeInfo, + get_metadata_from_schema, + arrow_to_dicts, +) from typing import Any from collections.abc import Mapping, Sequence import pyarrow as pa @@ -10,7 +15,9 @@ def is_packet_supported( - python_type_info: TypeSpec, registry: SemanticTypeRegistry, type_lut: dict | None = None + python_type_info: TypeSpec, + registry: SemanticTypeRegistry, + type_lut: dict | None = None, ) -> bool: """Check if all types in the packet are supported by the registry or known to the default lut.""" if type_lut is None: @@ -21,7 +28,6 @@ def is_packet_supported( ) - class PacketConverter: def __init__(self, python_type_spec: TypeSpec, registry: SemanticTypeRegistry): self.python_type_spec = python_type_spec @@ -174,4 +180,3 @@ def from_arrow_table( return storage_packets return [Packet(self._from_storage_packet(packet)) for packet in storage_packets] - diff --git a/src/orcapod/types/packets.py b/src/orcapod/types/packets.py index a6621ee..47df081 100644 --- a/src/orcapod/types/packets.py +++ b/src/orcapod/types/packets.py @@ -12,16 +12,17 @@ class Packet(dict[str, DataValue]): def __init__( - self, + self, obj: PacketLike | None = None, - typespec: TypeSpec | None = None, - source_info: dict[str, str|None] | None = None + typespec: TypeSpec | None = None, + source_info: dict[str, str | None] | None = None, ): if obj is None: obj = {} super().__init__(obj) if typespec is None: from orcapod.types.typespec_utils import get_typespec_from_dict + typespec = get_typespec_from_dict(self) self._typespec = typespec if source_info is None: @@ -36,18 +37,22 @@ def 
typespec(self) -> TypeSpec: @property def source_info(self) -> dict[str, str | None]: return {key: self._source_info.get(key, None) for key in self.keys()} - + @source_info.setter def source_info(self, source_info: Mapping[str, str | None]): - self._source_info = {key: value for key, value in source_info.items() if value is not None} + self._source_info = { + key: value for key, value in source_info.items() if value is not None + } def get_composite(self) -> PacketLike: composite = self.copy() for k, v in self.source_info.items(): composite[f"_source_info_{k}"] = v return composite - - def map_keys(self, mapping: Mapping[str, str], drop_unmapped: bool=False) -> 'Packet': + + def map_keys( + self, mapping: Mapping[str, str], drop_unmapped: bool = False + ) -> "Packet": """ Map the keys of the packet using the provided mapping. @@ -58,23 +63,25 @@ def map_keys(self, mapping: Mapping[str, str], drop_unmapped: bool=False) -> 'Pa A new Packet with keys mapped according to the provided mapping. """ if drop_unmapped: - new_content = { - v: self[k] for k, v in mapping.items() if k in self - } + new_content = {v: self[k] for k, v in mapping.items() if k in self} new_typespec = { v: self.typespec[k] for k, v in mapping.items() if k in self.typespec } new_source_info = { - v: self.source_info[k] for k, v in mapping.items() if k in self.source_info + v: self.source_info[k] + for k, v in mapping.items() + if k in self.source_info } else: new_content = {mapping.get(k, k): v for k, v in self.items()} new_typespec = {mapping.get(k, k): v for k, v in self.typespec.items()} - new_source_info = {mapping.get(k, k): v for k, v in self.source_info.items()} + new_source_info = { + mapping.get(k, k): v for k, v in self.source_info.items() + } return Packet(new_content, typespec=new_typespec, source_info=new_source_info) - - def join(self, other: 'Packet') -> 'Packet': + + def join(self, other: "Packet") -> "Packet": """ Join another packet to this one, merging their keys and values. @@ -86,13 +93,15 @@ def join(self, other: 'Packet') -> 'Packet': """ # make sure there is no key collision if not set(self.keys()).isdisjoint(other.keys()): - raise ValueError(f"Key collision detected: packets {self} and {other} have overlapping keys" - " and cannot be joined without losing information.") + raise ValueError( + f"Key collision detected: packets {self} and {other} have overlapping keys" + " and cannot be joined without losing information." + ) new_content = {**self, **other} new_typespec = {**self.typespec, **other.typespec} new_source_info = {**self.source_info, **other.source_info} - + return Packet(new_content, typespec=new_typespec, source_info=new_source_info) @@ -103,23 +112,30 @@ def join(self, other: 'Packet') -> 'Packet': class SemanticPacket(dict[str, Any]): """ A packet that conforms to a semantic schema, mapping string keys to values. - + This is used to represent data packets in OrcaPod with semantic types. - + Attributes ---------- keys : str The keys of the packet. values : Any The values corresponding to each key. 
- + Examples -------- >>> packet = SemanticPacket(name='Alice', age=30) >>> print(packet) {'name': 'Alice', 'age': 30} """ - def __init__(self, *args, semantic_schema: schemas.SemanticSchema | None = None, source_info: dict[str, str|None] | None = None, **kwargs): + + def __init__( + self, + *args, + semantic_schema: schemas.SemanticSchema | None = None, + source_info: dict[str, str | None] | None = None, + **kwargs, + ): super().__init__(*args, **kwargs) self.schema = semantic_schema if source_info is None: @@ -134,7 +150,12 @@ def get_composite(self) -> dict[str, Any]: class PacketConverter: - def __init__(self, typespec: TypeSpec, registry: SemanticTypeRegistry, include_source_info: bool = True): + def __init__( + self, + typespec: TypeSpec, + registry: SemanticTypeRegistry, + include_source_info: bool = True, + ): self.typespec = typespec self.registry = registry @@ -148,8 +169,6 @@ def __init__(self, typespec: TypeSpec, registry: SemanticTypeRegistry, include_s self.semantic_schema, include_source_info=self.include_source_info ) - - self.key_handlers: dict[str, TypeHandler] = {} self.expected_key_set = set(self.typespec.keys()) @@ -178,7 +197,9 @@ def _check_key_consistency(self, keys): raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") - def from_python_packet_to_semantic_packet(self, python_packet: PacketLike) -> SemanticPacket: + def from_python_packet_to_semantic_packet( + self, python_packet: PacketLike + ) -> SemanticPacket: """Convert a Python packet to a semantic packet. Args: @@ -193,22 +214,22 @@ def from_python_packet_to_semantic_packet(self, python_packet: PacketLike) -> Se ValueError: If conversion fails """ # Validate packet keys - semantic_packet = SemanticPacket(python_packet, semantic_schema=self.semantic_schema, source_info=getattr(python_packet, "source_info", None)) + semantic_packet = SemanticPacket( + python_packet, + semantic_schema=self.semantic_schema, + source_info=getattr(python_packet, "source_info", None), + ) self._check_key_consistency(set(semantic_packet.keys())) # convert from storage to Python types for semantic types for key, handler in self.key_handlers.items(): try: - semantic_packet[key] = handler.python_to_storage( - semantic_packet[key] - ) + semantic_packet[key] = handler.python_to_storage(semantic_packet[key]) except Exception as e: raise ValueError(f"Failed to convert value for '{key}': {e}") from e return semantic_packet - - def from_python_packet_to_arrow_table(self, python_packet: PacketLike) -> pa.Table: """Convert a Python packet to an Arrow table. @@ -221,7 +242,9 @@ def from_python_packet_to_arrow_table(self, python_packet: PacketLike) -> pa.Tab semantic_packet = self.from_python_packet_to_semantic_packet(python_packet) return self.from_semantic_packet_to_arrow_table(semantic_packet) - def from_semantic_packet_to_arrow_table(self, semantic_packet: SemanticPacket) -> pa.Table: + def from_semantic_packet_to_arrow_table( + self, semantic_packet: SemanticPacket + ) -> pa.Table: """Convert a semantic packet to an Arrow table. 
Args: @@ -231,12 +254,15 @@ def from_semantic_packet_to_arrow_table(self, semantic_packet: SemanticPacket) - Arrow table representation of the packet """ if self.include_source_info: - return pa.Table.from_pylist([semantic_packet.get_composite()], schema=self.arrow_schema) + return pa.Table.from_pylist( + [semantic_packet.get_composite()], schema=self.arrow_schema + ) else: return pa.Table.from_pylist([semantic_packet], schema=self.arrow_schema) - - def from_arrow_table_to_semantic_packets(self, arrow_table: pa.Table) -> Collection[SemanticPacket]: + def from_arrow_table_to_semantic_packets( + self, arrow_table: pa.Table + ) -> Collection[SemanticPacket]: """Convert an Arrow table to a semantic packet. Args: @@ -249,18 +275,34 @@ def from_arrow_table_to_semantic_packets(self, arrow_table: pa.Table) -> Collect # schema matches what's expected if not arrow_table.schema.equals(self.arrow_schema): raise ValueError("Arrow table schema does not match expected schema") - + semantic_packets_contents = arrow_table.to_pylist() - + semantic_packets = [] for all_packet_content in semantic_packets_contents: - packet_content = {k: v for k, v in all_packet_content.items() if k in self.expected_key_set} - source_info = {k.removeprefix('_source_info_'): v for k, v in all_packet_content.items() if k.startswith('_source_info_')} - semantic_packets.append(SemanticPacket(packet_content, semantic_schema=self.semantic_schema, source_info=source_info)) + packet_content = { + k: v + for k, v in all_packet_content.items() + if k in self.expected_key_set + } + source_info = { + k.removeprefix("_source_info_"): v + for k, v in all_packet_content.items() + if k.startswith("_source_info_") + } + semantic_packets.append( + SemanticPacket( + packet_content, + semantic_schema=self.semantic_schema, + source_info=source_info, + ) + ) return semantic_packets - def from_semantic_packet_to_python_packet(self, semantic_packet: SemanticPacket) -> Packet: + def from_semantic_packet_to_python_packet( + self, semantic_packet: SemanticPacket + ) -> Packet: """Convert a semantic packet to a Python packet. 
Args: @@ -270,18 +312,20 @@ def from_semantic_packet_to_python_packet(self, semantic_packet: SemanticPacket) Python packet representation of the semantic packet """ # Validate packet keys - python_packet = Packet(semantic_packet, typespec=self.typespec, source_info=semantic_packet.source_info) + python_packet = Packet( + semantic_packet, + typespec=self.typespec, + source_info=semantic_packet.source_info, + ) packet_keys = set(python_packet.keys()) self._check_key_consistency(packet_keys) for key, handler in self.key_handlers.items(): try: - python_packet[key] = handler.storage_to_python( - python_packet[key] - ) + python_packet[key] = handler.storage_to_python(python_packet[key]) except Exception as e: raise ValueError(f"Failed to convert value for '{key}': {e}") from e - + return python_packet def from_arrow_table_to_python_packets(self, arrow_table: pa.Table) -> list[Packet]: @@ -294,5 +338,6 @@ def from_arrow_table_to_python_packets(self, arrow_table: pa.Table) -> list[Pack List of Python packets converted from the Arrow table """ semantic_packets = self.from_arrow_table_to_semantic_packets(arrow_table) - return [self.from_semantic_packet_to_python_packet(sp) for sp in semantic_packets] - + return [ + self.from_semantic_packet_to_python_packet(sp) for sp in semantic_packets + ] diff --git a/src/orcapod/types/schemas.py b/src/orcapod/types/schemas.py index 19e8a3b..dc2112f 100644 --- a/src/orcapod/types/schemas.py +++ b/src/orcapod/types/schemas.py @@ -1,5 +1,4 @@ - -from orcapod.types import TypeSpec +from orcapod.types import TypeSpec from orcapod.types.semantic_type_registry import SemanticTypeRegistry from typing import Any import pyarrow as pa @@ -14,11 +13,13 @@ bool: pa.bool_(), } + def python_to_arrow_type(python_type: type) -> pa.DataType: if python_type in DEFAULT_ARROW_TYPE_LUT: return DEFAULT_ARROW_TYPE_LUT[python_type] raise TypeError(f"Converstion of python type {python_type} is not supported yet") + def arrow_to_python_type(arrow_type: pa.DataType) -> type: if pa.types.is_integer(arrow_type): return int @@ -38,68 +39,68 @@ def arrow_to_python_type(arrow_type: pa.DataType) -> type: raise TypeError(f"Conversion of arrow type {arrow_type} is not supported") - class PythonSchema(dict[str, type]): """ A schema for Python data types, mapping string keys to Python types. - + This is used to define the expected structure of data packets in OrcaPod. - + Attributes ---------- keys : str The keys of the schema. values : type The types corresponding to each key. - + Examples -------- >>> schema = PythonSchema(name=str, age=int) >>> print(schema) {'name': , 'age': } """ + @property def with_source_info(self) -> dict[str, type]: """ Get the schema with source info fields included. - + Returns ------- dict[str, type|None] A new schema including source info fields. """ - return {**self, **{f'_source_info_{k}': str for k in self.keys()}} + return {**self, **{f"_source_info_{k}": str for k in self.keys()}} - -class SemanticSchema(dict[str, tuple[type, str|None]]): +class SemanticSchema(dict[str, tuple[type, str | None]]): """ A schema for semantic types, mapping string keys to tuples of Python types and optional metadata. - + This is used to define the expected structure of data packets with semantic types in OrcaPod. - + Attributes ---------- keys : str The keys of the schema. values : tuple[type, str|None] The types and optional semantic type corresponding to each key. 
- + Examples -------- >>> schema = SemanticSchema(image=(str, 'path'), age=(int, None)) >>> print(schema) {'image': (, 'path'), 'age': (, None)} """ + def get_store_type(self, key: str) -> type | None: """ Get the storage type for a given key in the schema. - + Parameters ---------- key : str The key for which to retrieve the storage type. - + Returns ------- type | None @@ -110,24 +111,24 @@ def get_store_type(self, key: str) -> type | None: def get_semantic_type(self, key: str) -> str | None: """ Get the semantic type for a given key in the schema. - + Parameters ---------- key : str The key for which to retrieve the semantic type. - + Returns ------- str | None The semantic type associated with the key, or None if not found. """ return self.get(key, (None, None))[1] - + @property def storage_schema(self) -> PythonSchema: """ Get the storage schema, which is a PythonSchema representation of the semantic schema. - + Returns ------- PythonSchema @@ -135,17 +136,16 @@ def storage_schema(self) -> PythonSchema: """ return PythonSchema({k: v[0] for k, v in self.items()}) - @property def storage_schema_with_source_info(self) -> dict[str, type]: """ Get the storage schema with source info fields included. - + Returns ------- dict[str, type] A new schema including source info fields. - + Examples -------- >>> semantic_schema = SemanticSchema(name=(str, 'name'), age=(int, None)) @@ -162,19 +162,19 @@ def from_typespec_to_semantic_schema( ) -> SemanticSchema: """ Convert a Python schema to a semantic schema using the provided semantic type registry. - + Parameters ---------- typespec : TypeSpec The typespec to convert, mapping keys to Python types. semantic_type_registry : SemanticTypeRegistry The registry containing semantic type information. - + Returns ------- SemanticSchema A new schema mapping keys to tuples of Python types and optional semantic type identifiers. - + Examples -------- >>> typespec: TypeSpec = dict(name=str, age=int) @@ -186,31 +186,34 @@ def from_typespec_to_semantic_schema( for key, python_type in typespec.items(): if python_type in semantic_type_registry: type_info = semantic_type_registry.get_type_info(python_type) - assert type_info is not None, f"Type {python_type} should be found in the registry as `in` returned True" + assert type_info is not None, ( + f"Type {python_type} should be found in the registry as `in` returned True" + ) semantic_schema[key] = (type_info.storage_type, type_info.semantic_type) else: semantic_schema[key] = (python_type, None) return SemanticSchema(semantic_schema) + def from_semantic_schema_to_python_schema( semantic_schema: SemanticSchema, semantic_type_registry: SemanticTypeRegistry, ) -> PythonSchema: """ Convert a semantic schema to a Python schema using the provided semantic type registry. - + Parameters ---------- semantic_schema : SemanticSchema The schema to convert, mapping keys to tuples of Python types and optional semantic type identifiers. semantic_type_registry : SemanticTypeRegistry The registry containing semantic type information. - + Returns ------- PythonSchema A new schema mapping keys to Python types. - + Examples -------- >>> semantic_schema = SemanticSchema(name=(str, None), age=(int, None)) @@ -226,23 +229,24 @@ def from_semantic_schema_to_python_schema( python_schema_content[key] = python_type return PythonSchema(python_schema_content) + def from_semantic_schema_to_arrow_schema( semantic_schema: SemanticSchema, include_source_info: bool = True, ) -> pa.Schema: """ Convert a semantic schema to an Arrow schema. 
- + Parameters ---------- semantic_schema : SemanticSchema The schema to convert, mapping keys to tuples of Python types and optional semantic type identifiers. - + Returns ------- dict[str, type] A new schema mapping keys to Arrow-compatible types. - + Examples -------- >>> semantic_schema = SemanticSchema(name=(str, None), age=(int, None)) @@ -253,32 +257,39 @@ def from_semantic_schema_to_arrow_schema( fields = [] for field_name, (python_type, semantic_type) in semantic_schema.items(): arrow_type = DEFAULT_ARROW_TYPE_LUT[python_type] - field_metadata = {b"semantic_type": semantic_type.encode('utf-8')} if semantic_type else {} + field_metadata = ( + {b"semantic_type": semantic_type.encode("utf-8")} if semantic_type else {} + ) fields.append(pa.field(field_name, arrow_type, metadata=field_metadata)) if include_source_info: for field in semantic_schema: - field_metadata = {b'field_type': b'source_info'} - fields.append(pa.field(f'_source_info_{field}', pa.large_string(), metadata=field_metadata)) - + field_metadata = {b"field_type": b"source_info"} + fields.append( + pa.field( + f"_source_info_{field}", pa.large_string(), metadata=field_metadata + ) + ) + return pa.schema(fields) + def from_arrow_schema_to_semantic_schema( arrow_schema: pa.Schema, ) -> SemanticSchema: """ Convert an Arrow schema to a semantic schema. - + Parameters ---------- arrow_schema : pa.Schema The schema to convert, containing fields with metadata. - + Returns ------- SemanticSchema A new schema mapping keys to tuples of Python types and optional semantic type identifiers. - + Examples -------- >>> arrow_schema = pa.schema([pa.field('name', pa.string(), metadata={'semantic_type': 'name'}), @@ -289,19 +300,25 @@ def from_arrow_schema_to_semantic_schema( """ semantic_schema = {} for field in arrow_schema: - if field.metadata.get(b'field_type', b'') == b'source_info': + if field.metadata.get(b"field_type", b"") == b"source_info": # Skip source info fields continue - semantic_type = field.metadata.get(b'semantic_type', None) + semantic_type = field.metadata.get(b"semantic_type", None) semantic_type = semantic_type.decode() if semantic_type else None python_type = arrow_to_python_type(field.type) semantic_schema[field.name] = (python_type, semantic_type) return SemanticSchema(semantic_schema) -def from_typespec_to_arrow_schema(typespec: TypeSpec, - semantic_type_registry: SemanticTypeRegistry, include_source_info: bool = True) -> pa.Schema: + +def from_typespec_to_arrow_schema( + typespec: TypeSpec, + semantic_type_registry: SemanticTypeRegistry, + include_source_info: bool = True, +) -> pa.Schema: semantic_schema = from_typespec_to_semantic_schema(typespec, semantic_type_registry) - return from_semantic_schema_to_arrow_schema(semantic_schema, include_source_info=include_source_info) + return from_semantic_schema_to_arrow_schema( + semantic_schema, include_source_info=include_source_info + ) def from_arrow_schema_to_python_schema( @@ -310,17 +327,17 @@ def from_arrow_schema_to_python_schema( ) -> PythonSchema: """ Convert an Arrow schema to a Python schema. - + Parameters ---------- arrow_schema : pa.Schema The schema to convert, containing fields with metadata. - + Returns ------- PythonSchema A new schema mapping keys to Python types. 
- + Examples -------- >>> arrow_schema = pa.schema([pa.field('name', pa.string()), pa.field('age', pa.int64())]) @@ -329,4 +346,6 @@ def from_arrow_schema_to_python_schema( {'name': , 'age': } """ semantic_schema = from_arrow_schema_to_semantic_schema(arrow_schema) - return from_semantic_schema_to_python_schema(semantic_schema, semantic_type_registry) \ No newline at end of file + return from_semantic_schema_to_python_schema( + semantic_schema, semantic_type_registry + ) diff --git a/src/orcapod/types/semantic_type_registry.py b/src/orcapod/types/semantic_type_registry.py index d5a677f..2091904 100644 --- a/src/orcapod/types/semantic_type_registry.py +++ b/src/orcapod/types/semantic_type_registry.py @@ -33,7 +33,9 @@ def __init__(self): type, tuple[TypeHandler, str] ] = {} # PythonType -> (Handler, semantic_name) self._semantic_handlers: dict[str, TypeHandler] = {} # semantic_name -> Handler - self._semantic_to_python_lut: dict[str, type] = {} # semantic_name -> Python type + self._semantic_to_python_lut: dict[ + str, type + ] = {} # semantic_name -> Python type def register( self, @@ -49,7 +51,7 @@ def register( override: If True, allow overriding existing registration for the same semantic name and Python type(s) """ # Determine which types to register for - + python_type = handler.python_type() # Register handler for each type @@ -59,7 +61,7 @@ def register( raise ValueError( f"Type {python_type} already registered with semantic type '{existing_semantic}'" ) - + # Register by semantic name if semantic_type in self._semantic_handlers: raise ValueError(f"Semantic type '{semantic_type}' already registered") @@ -78,12 +80,12 @@ def lookup_handler_info(self, python_type: type) -> tuple[TypeHandler, str] | No if issubclass(python_type, registered_type): return (handler, semantic_type) return None - + def get_semantic_type(self, python_type: type) -> str | None: """Get semantic type for a Python type.""" handler_info = self.lookup_handler_info(python_type) return handler_info[1] if handler_info else None - + def get_handler(self, python_type: type) -> TypeHandler | None: """Get handler for a Python type.""" handler_info = self.lookup_handler_info(python_type) @@ -92,7 +94,6 @@ def get_handler(self, python_type: type) -> TypeHandler | None: def get_handler_by_semantic_type(self, semantic_type: str) -> TypeHandler | None: """Get handler by semantic type.""" return self._semantic_handlers.get(semantic_type) - def get_type_info(self, python_type: type) -> TypeInfo | None: """Get TypeInfo for a Python type.""" @@ -107,7 +108,6 @@ def get_type_info(self, python_type: type) -> TypeInfo | None: handler=handler, ) - def __contains__(self, python_type: type) -> bool: """Check if a Python type is registered.""" for registered_type in self._handlers: @@ -116,18 +116,14 @@ def __contains__(self, python_type: type) -> bool: return False - - - - # Below is a collection of functions that handles converting between various aspects of Python packets and Arrow tables. # Here for convenience, any Python dictionary with str keys and supported Python values are referred to as a packet. 
# Conversions are: -# python packet <-> storage packet <-> arrow table +# python packet <-> storage packet <-> arrow table # python typespec <-> storage typespec <-> arrow schema -# +# # python packet <-> storage packet requires the use of SemanticTypeRegistry # conversion between storage packet <-> arrow table requires info about semantic_type @@ -152,13 +148,13 @@ def __contains__(self, python_type: type) -> bool: # """Convert Arrow Schema to storage typespec and semantic type metadata.""" # typespec = {} # semantic_type_info = {} - + # for field in schema: # field_type = field.type # typespec[field.name] = field_type.to_pandas_dtype() # Convert Arrow type to Pandas dtype # if field.metadata and b"semantic_type" in field.metadata: # semantic_type_info[field.name] = field.metadata[b"semantic_type"].decode("utf-8") - + # return typespec, semantic_type_info @@ -168,14 +164,9 @@ def __contains__(self, python_type: type) -> bool: # semantic_type_info: dict[str, str] | None = None, - # # TypeSpec + TypeRegistry + ArrowLUT -> Arrow Schema (annotated with semantic_type) -# # - - - - +# # # # TypeSpec <-> Arrow Schema @@ -184,7 +175,7 @@ def __contains__(self, python_type: type) -> bool: # """Convert TypeSpec to PyArrow Schema.""" # if metadata_info is None: # metadata_info = {} - + # fields = [] # for field_name, field_type in typespec.items(): # type_info = registry.get_type_info(field_type) @@ -227,7 +218,6 @@ def __contains__(self, python_type: type) -> bool: # return keys_with_handlers, pa.schema(schema_fields) - # def arrow_table_to_packets( # table: pa.Table, # registry: SemanticTypeRegistry, @@ -347,14 +337,14 @@ def __contains__(self, python_type: type) -> bool: # bool: pa.bool_(), # bytes: pa.binary(), # } - + # if python_type in basic_mapping: # return basic_mapping[python_type] - + # # Handle generic types # origin = get_origin(python_type) # args = get_args(python_type) - + # if origin is list: # # Handle list[T] # if args: @@ -362,7 +352,7 @@ def __contains__(self, python_type: type) -> bool: # return pa.list_(element_type) # else: # return pa.list_(pa.large_string()) # default to list of strings - + # elif origin is dict: # # Handle dict[K, V] - PyArrow uses map type # if len(args) == 2: @@ -372,13 +362,13 @@ def __contains__(self, python_type: type) -> bool: # else: # # Otherwise default to using long string # return pa.map_(pa.large_string(), pa.large_string()) - + # elif origin is UnionType: # # Handle Optional[T] (Union[T, None]) # if len(args) == 2 and type(None) in args: # non_none_type = args[0] if args[1] is type(None) else args[1] # return python_to_pyarrow_type(non_none_type) - + # # Default fallback # if not strict: # logger.warning(f"Unsupported type {python_type}, defaulting to large_string") diff --git a/src/orcapod/types/typespec_utils.py b/src/orcapod/types/typespec_utils.py index 0786d10..4e48004 100644 --- a/src/orcapod/types/typespec_utils.py +++ b/src/orcapod/types/typespec_utils.py @@ -214,7 +214,6 @@ def extract_function_typespecs( return param_info, inferred_output_types - def get_typespec_from_dict(dict: Mapping) -> TypeSpec: """ Returns a TypeSpec for the given dictionary. 
@@ -248,7 +247,10 @@ def union_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | ) return merged -def intersection_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: + +def intersection_typespecs( + left: TypeSpec | None, right: TypeSpec | None +) -> TypeSpec | None: """ Returns the intersection of two TypeSpecs, only returning keys that are present in both. If a key is present in both TypeSpecs, the type must be the same. @@ -263,6 +265,8 @@ def intersection_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> Typ intersection[key] = get_compatible_type(left[key], right[key]) except TypeError: # If types are not compatible, raise an error - raise TypeError(f"Type conflict for key '{key}': {left[key]} vs {right[key]}") - - return intersection \ No newline at end of file + raise TypeError( + f"Type conflict for key '{key}': {left[key]} vs {right[key]}" + ) + + return intersection diff --git a/src/orcapod/utils/object_spec.py b/src/orcapod/utils/object_spec.py index f359a8c..dd09e1f 100644 --- a/src/orcapod/utils/object_spec.py +++ b/src/orcapod/utils/object_spec.py @@ -1,5 +1,6 @@ import importlib + def parse_objectspec(obj_spec: dict) -> Any: if "_class" in obj_spec: # if _class is specified, treat the dict as an object specification @@ -16,4 +17,4 @@ def parse_objectspec(obj_spec: dict) -> Any: parsed_object[k] = parse_objectspec(v) else: parsed_object[k] = v - return parsed_object \ No newline at end of file + return parsed_object diff --git a/src/orcapod/utils/stream_utils.py b/src/orcapod/utils/stream_utils.py index 5c5bb62..4246088 100644 --- a/src/orcapod/utils/stream_utils.py +++ b/src/orcapod/utils/stream_utils.py @@ -12,7 +12,6 @@ V = TypeVar("V") - def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: merged = left.copy() for key, right_value in right.items(): @@ -26,8 +25,6 @@ def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: return merged - - def common_elements(*values) -> Collection[str]: """ Returns the common keys between all lists of values. The identified common elements are @@ -57,7 +54,10 @@ def join_tags(tag1: Mapping[K, V], tag2: Mapping[K, V]) -> dict[K, V] | None: joined_tag[k] = v return joined_tag -def semijoin_tags(tag1: Mapping[K, V], tag2: Mapping[K, V], target_keys: Collection[K]|None = None) -> dict[K, V] | None: + +def semijoin_tags( + tag1: Mapping[K, V], tag2: Mapping[K, V], target_keys: Collection[K] | None = None +) -> dict[K, V] | None: """ Semijoin two tags. If the tags have the same key, the value must be the same or None will be returned. If all shared key's value match, tag1 would be returned @@ -72,6 +72,7 @@ def semijoin_tags(tag1: Mapping[K, V], tag2: Mapping[K, V], target_keys: Collect return None return dict(tag1) + def check_packet_compatibility(packet1: Packet, packet2: Packet) -> bool: """ Checks if two packets are compatible. If the packets have the same key, the value must be the same or False will be returned. 
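The Packet class reworked in the packets.py hunks above is essentially a dict with an attached typespec and per-key source_info, plus join() and map_keys() helpers. The following is a minimal usage sketch based only on the methods visible in this patch; the import path is assumed from the file location (src/orcapod/types/packets.py) and the example values are made up for illustration.

    from orcapod.types.packets import Packet

    image_packet = Packet({"image": "/data/frame_0001.png"}, typespec={"image": str})
    label_packet = Packet({"label": 3}, typespec={"label": int})

    # join() refuses overlapping keys, so the merged packet keeps the typespec
    # and source_info of both sides without silently dropping information.
    combined = image_packet.join(label_packet)

    # map_keys() renames keys; with the default drop_unmapped=False, keys that
    # are absent from the mapping are carried over unchanged.
    renamed = combined.map_keys({"image": "input_image"})
    assert sorted(renamed.keys()) == ["input_image", "label"]
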
diff --git a/tests/test_hashing/test_basic_composite_hasher.py b/tests/test_hashing/test_basic_composite_hasher.py index f2da406..a2d35a6 100644 --- a/tests/test_hashing/test_basic_composite_hasher.py +++ b/tests/test_hashing/test_basic_composite_hasher.py @@ -181,7 +181,9 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + algorithm=algorithm + ) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -193,7 +195,9 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + buffer_size=buffer_size + ) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -222,7 +226,9 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + algorithm=algorithm + ) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -234,7 +240,9 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + buffer_size=buffer_size + ) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -266,7 +274,9 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + algorithm=algorithm + ) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) @@ -285,7 +295,9 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite(buffer_size=buffer_size) + hasher = LegacyPathLikeHasherFactory.create_basic_legacy_composite( + buffer_size=buffer_size + ) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" diff --git a/tests/test_hashing/test_hasher_factory.py b/tests/test_hashing/test_hasher_factory.py index 69804a3..68daa3a 100644 --- a/tests/test_hashing/test_hasher_factory.py +++ b/tests/test_hashing/test_hasher_factory.py @@ -30,7 +30,9 @@ def test_create_file_hasher_without_cacher(self): def test_create_file_hasher_with_cacher(self): """Test creating a file hasher with string cacher (returns CachedFileHasher).""" cacher = InMemoryCacher() - hasher = 
LegacyPathLikeHasherFactory.create_legacy_file_hasher(string_cacher=cacher) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( + string_cacher=cacher + ) # Should return LegacyCachedFileHasher assert isinstance(hasher, LegacyCachedFileHasher) @@ -56,13 +58,15 @@ def test_create_file_hasher_custom_algorithm(self): ) assert isinstance(hasher, LegacyCachedFileHasher) assert isinstance(hasher.file_hasher, LegacyDefaultFileHasher) - assert hasher.file_hasher.algorithm == "sha512" - assert hasher.file_hasher.buffer_size == 65536 + assert hasher.file_hasher.algorithm == "sha512" + assert hasher.file_hasher.buffer_size == 65536 def test_create_file_hasher_custom_buffer_size(self): """Test creating file hasher with custom buffer size.""" # Without cacher - hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(buffer_size=32768) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( + buffer_size=32768 + ) assert isinstance(hasher, LegacyDefaultFileHasher) assert hasher.algorithm == "sha256" assert hasher.buffer_size == 32768 @@ -94,7 +98,9 @@ def test_create_file_hasher_different_cacher_types(self): """Test creating file hasher with different types of string cachers.""" # InMemoryCacher memory_cacher = InMemoryCacher() - hasher1 = LegacyPathLikeHasherFactory.create_legacy_file_hasher(string_cacher=memory_cacher) + hasher1 = LegacyPathLikeHasherFactory.create_legacy_file_hasher( + string_cacher=memory_cacher + ) assert isinstance(hasher1, LegacyCachedFileHasher) assert hasher1.string_cacher is memory_cacher @@ -184,13 +190,17 @@ def test_create_file_hasher_parameter_edge_cases(self): assert hasher1.buffer_size == 1 # Large buffer size - hasher2 = LegacyPathLikeHasherFactory.create_legacy_file_hasher(buffer_size=1024 * 1024) + hasher2 = LegacyPathLikeHasherFactory.create_legacy_file_hasher( + buffer_size=1024 * 1024 + ) assert isinstance(hasher2, LegacyDefaultFileHasher) assert hasher2.buffer_size == 1024 * 1024 # Different algorithms for algorithm in ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]: - hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher(algorithm=algorithm) + hasher = LegacyPathLikeHasherFactory.create_legacy_file_hasher( + algorithm=algorithm + ) assert isinstance(hasher, LegacyDefaultFileHasher) assert hasher.algorithm == algorithm diff --git a/tests/test_hashing/test_legacy_composite_hasher.py b/tests/test_hashing/test_legacy_composite_hasher.py index f3a8de4..f234bb7 100644 --- a/tests/test_hashing/test_legacy_composite_hasher.py +++ b/tests/test_hashing/test_legacy_composite_hasher.py @@ -6,8 +6,15 @@ import pytest from orcapod.hashing.legacy_core import hash_to_hex -from orcapod.hashing.file_hashers import LegacyDefaultFileHasher, LegacyDefaultCompositeFileHasher -from orcapod.hashing.types import LegacyFileHasher, LegacyPacketHasher, LegacyPathSetHasher +from orcapod.hashing.file_hashers import ( + LegacyDefaultFileHasher, + LegacyDefaultCompositeFileHasher, +) +from orcapod.hashing.types import ( + LegacyFileHasher, + LegacyPacketHasher, + LegacyPathSetHasher, +) # Custom implementation of hash_file for tests that doesn't check for file existence @@ -90,7 +97,9 @@ def patch_hash_functions(): """Patch the hash functions in the core module for all tests.""" with ( patch("orcapod.hashing.legacy_core.hash_file", side_effect=mock_hash_file), - patch("orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset), + patch( + "orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset + ), 
patch("orcapod.hashing.legacy_core.hash_packet", side_effect=mock_hash_packet), ): yield diff --git a/tests/test_hashing/test_path_set_hasher.py b/tests/test_hashing/test_path_set_hasher.py index 0a48acb..c235eb0 100644 --- a/tests/test_hashing/test_path_set_hasher.py +++ b/tests/test_hashing/test_path_set_hasher.py @@ -86,7 +86,9 @@ def mock_hash_pathset( @pytest.fixture(autouse=True) def patch_hash_pathset(): """Patch the hash_pathset function in the hashing module for all tests.""" - with patch("orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset): + with patch( + "orcapod.hashing.legacy_core.hash_pathset", side_effect=mock_hash_pathset + ): yield @@ -225,7 +227,8 @@ def custom_hash_nonexistent(pathset, **kwargs): # Patch hash_pathset just for this test with patch( - "orcapod.hashing.legacy_core.hash_pathset", side_effect=custom_hash_nonexistent + "orcapod.hashing.legacy_core.hash_pathset", + side_effect=custom_hash_nonexistent, ): result = pathset_hasher.hash_pathset(pathset) diff --git a/tests/test_hashing/test_string_cacher/test_redis_cacher.py b/tests/test_hashing/test_string_cacher/test_redis_cacher.py index 3ef49e1..eef7c43 100644 --- a/tests/test_hashing/test_string_cacher/test_redis_cacher.py +++ b/tests/test_hashing/test_string_cacher/test_redis_cacher.py @@ -68,21 +68,21 @@ def keys(self, pattern): return [key for key in self.data.keys() if key.startswith(prefix)] return [key for key in self.data.keys() if key == pattern] + class MockRedisModule: ConnectionError = MockConnectionError RedisError = MockRedisError Redis = MagicMock(return_value=MockRedis()) # Simple one-liner! - def mock_get_redis(): return MockRedisModule + def mock_no_redis(): return None - class TestRedisCacher: """Test cases for RedisCacher with mocked Redis.""" From cbe82aba171a35d21dc4c29aae7113c8c7c9f107 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 16:31:50 +0000 Subject: [PATCH 35/57] fix: legacy_core imports --- tests/test_hashing/test_file_hashes.py | 2 +- tests/test_hashing/test_hash_samples.py | 2 +- tests/test_hashing/test_pathset_and_packet.py | 2 +- tests/test_hashing/test_pathset_packet_hashes.py | 2 +- tests/test_store/test_dir_data_store.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_hashing/test_file_hashes.py b/tests/test_hashing/test_file_hashes.py index 1de0716..afcaaad 100644 --- a/tests/test_hashing/test_file_hashes.py +++ b/tests/test_hashing/test_file_hashes.py @@ -12,7 +12,7 @@ import pytest # Add the parent directory to the path to import orcapod -from orcapod.hashing import hash_file +from orcapod.hashing.legacy_core import hash_file def load_hash_lut(): diff --git a/tests/test_hashing/test_hash_samples.py b/tests/test_hashing/test_hash_samples.py index cfb3e35..1e536cb 100644 --- a/tests/test_hashing/test_hash_samples.py +++ b/tests/test_hashing/test_hash_samples.py @@ -12,7 +12,7 @@ import pytest -from orcapod.hashing import hash_to_hex, hash_to_int, hash_to_uuid +from orcapod.hashing.legacy_core import hash_to_hex, hash_to_int, hash_to_uuid def get_latest_hash_samples(): diff --git a/tests/test_hashing/test_pathset_and_packet.py b/tests/test_hashing/test_pathset_and_packet.py index fc00b29..cde79da 100644 --- a/tests/test_hashing/test_pathset_and_packet.py +++ b/tests/test_hashing/test_pathset_and_packet.py @@ -13,7 +13,7 @@ import pytest -from orcapod.hashing import hash_file, hash_packet, hash_pathset +from orcapod.hashing.legacy_core import hash_file, hash_packet, hash_pathset logger = logging.getLogger(__name__) diff --git a/tests/test_hashing/test_pathset_packet_hashes.py b/tests/test_hashing/test_pathset_packet_hashes.py index 7745881..7df740d 100644 --- a/tests/test_hashing/test_pathset_packet_hashes.py +++ b/tests/test_hashing/test_pathset_packet_hashes.py @@ -12,7 +12,7 @@ import pytest # Add the parent directory to the path to import orcapod -from orcapod.hashing import hash_packet, hash_pathset +from orcapod.hashing.legacy_core import hash_packet, hash_pathset def load_pathset_hash_lut(): diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index eae39eb..09d84d7 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -499,7 +499,7 @@ def test_dir_data_store_legacy_mode_compatibility(temp_dir, sample_files): output_packet = {"output_file": sample_files["output"]["output1"]} # Get the hash values directly for comparison - from orcapod.hashing import hash_packet + from orcapod.hashing.legacy_core import hash_packet legacy_hash = hash_packet(packet, algorithm="sha256") assert store_default.packet_hasher is not None, ( @@ -610,7 +610,7 @@ def test_dir_data_store_hash_equivalence(temp_dir, sample_files): output_packet = {"output_file": sample_files["output"]["output1"]} # First compute hashes directly - from orcapod.hashing import hash_packet + from orcapod.hashing.legacy_core import hash_packet from orcapod.hashing.defaults import get_default_composite_file_hasher legacy_hash = hash_packet(packet, algorithm="sha256") From caca67b5ea0444cfcbfdaaf5db25e330dcb841dd Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 19:39:51 +0000 Subject: [PATCH 36/57] wip: arrow logical serialization --- src/orcapod/hashing/arrow_serialization.py | 822 +++++++++++++++++++++ 1 file changed, 822 insertions(+) create mode 100644 src/orcapod/hashing/arrow_serialization.py diff --git a/src/orcapod/hashing/arrow_serialization.py b/src/orcapod/hashing/arrow_serialization.py new file mode 100644 index 0000000..e4926cb --- /dev/null +++ b/src/orcapod/hashing/arrow_serialization.py @@ -0,0 +1,822 @@ +import pyarrow as pa +from io import BytesIO +import pyarrow.ipc as ipc +import struct +from typing import Any +import hashlib + + +def serialize_table_ipc(table: pa.Table) -> bytes: + # TODO: fix and use logical table hashing instead + """Serialize table using Arrow IPC format for stable binary representation.""" + buffer = BytesIO() + + # Write format version + buffer.write(b"ARROW_IPC_V1") + + # Use IPC stream format for deterministic serialization + with ipc.new_stream(buffer, table.schema) as writer: + writer.write_table(table) + + return buffer.getvalue() + + +def serialize_table_logical(table: pa.Table) -> bytes: + """ + Serialize table using column-wise processing with direct binary data access. + + This implementation works directly with Arrow's underlying binary buffers + without converting to Python objects, making it much faster and more + memory efficient while maintaining high repeatability. + """ + buffer = BytesIO() + + # Write format version + buffer.write(b"ARROW_BINARY_V1") + + # Serialize schema deterministically + _serialize_schema_deterministic(buffer, table.schema) + + # Process each column using direct binary access + column_digests = [] + for i in range(table.num_columns): + column = table.column(i) + field = table.schema.field(i) + column_digest = _serialize_column_binary(column, field) + column_digests.append(column_digest) + + # Combine column digests + for digest in column_digests: + buffer.write(digest) + + return buffer.getvalue() + + +def _serialize_schema_deterministic(buffer: BytesIO, schema: pa.Schema) -> None: + """Serialize schema information deterministically.""" + buffer.write(struct.pack(" None: + """Serialize Arrow data type deterministically.""" + type_id = data_type.id + buffer.write(struct.pack(" bytes: + """ + Serialize column using direct binary buffer access. + + To ensure chunking independence, we combine chunks into a single array + before processing. This ensures identical output regardless of chunk boundaries. 
+ """ + buffer = BytesIO() + + # Combine all chunks into a single array for consistent processing + if column.num_chunks > 1: + # Multiple chunks - combine them + combined_array = pa.concat_arrays(column.chunks) + elif column.num_chunks == 1: + # Single chunk - use directly + combined_array = column.chunk(0) + else: + # No chunks - create empty array + combined_array = pa.array([], type=field.type) + + # Process the combined array + chunk_result = _serialize_array_binary(combined_array, field.type) + buffer.write(chunk_result) + + return buffer.getvalue() + + +def _serialize_array_binary(array: pa.Array, data_type: pa.DataType) -> bytes: + """Serialize array using direct access to Arrow's binary buffers.""" + buffer = BytesIO() + + # Get validity buffer (null bitmap) if it exists + validity_buffer = None + if array.buffers()[0] is not None: + validity_buffer = array.buffers()[0] + + # Process based on Arrow type, accessing buffers directly + if _is_primitive_type(data_type): + _serialize_primitive_array_binary(buffer, array, data_type, validity_buffer) + + elif pa.types.is_string(data_type) or pa.types.is_large_string(data_type): + _serialize_string_array_binary(buffer, array, data_type, validity_buffer) + + elif pa.types.is_binary(data_type) or pa.types.is_large_binary(data_type): + _serialize_binary_array_binary(buffer, array, data_type, validity_buffer) + + elif pa.types.is_list(data_type) or pa.types.is_large_list(data_type): + _serialize_list_array_binary(buffer, array, data_type, validity_buffer) + + elif pa.types.is_struct(data_type): + _serialize_struct_array_binary(buffer, array, data_type, validity_buffer) + + elif pa.types.is_dictionary(data_type): + _serialize_dictionary_array_binary(buffer, array, data_type, validity_buffer) + + else: + # Fallback to element-wise processing for complex types + _serialize_array_fallback(buffer, array, data_type, validity_buffer) + + return buffer.getvalue() + + +def _is_primitive_type(data_type: pa.DataType) -> bool: + """Check if type can be processed as primitive (fixed-size) data.""" + return ( + pa.types.is_integer(data_type) + or pa.types.is_floating(data_type) + or pa.types.is_boolean(data_type) + or pa.types.is_date(data_type) + or pa.types.is_time(data_type) + or pa.types.is_timestamp(data_type) + ) + + +def _serialize_primitive_array_binary( + buffer: BytesIO, array: pa.Array, data_type: pa.DataType, validity_buffer +): + """Serialize primitive arrays by directly copying binary data.""" + # Write validity bitmap + _serialize_validity_buffer(buffer, validity_buffer) + + # Get data buffer (buffer[1] for primitive types) + data_buffer = array.buffers()[1] + if data_buffer is not None: + # For primitive types, copy the buffer directly + if pa.types.is_boolean(data_type): + # Boolean needs the length for bit interpretation + buffer.write(struct.pack(" 0: + child_array = array.children[0] + + # Recursively serialize child array + if child_array is not None: + child_data = _serialize_array_binary(child_array, data_type.value_type) + buffer.write(child_data) + + +def _serialize_struct_array_binary( + buffer: BytesIO, array: pa.Array, data_type: pa.DataType, validity_buffer +): + """Serialize struct arrays by processing child arrays.""" + # Write validity bitmap + _serialize_validity_buffer(buffer, validity_buffer) + + # Serialize each child field + for i, child_array in enumerate(array.children): + field_type = data_type[i].type + child_data = _serialize_array_binary(child_array, field_type) + buffer.write(child_data) + + +def 
_serialize_dictionary_array_binary( + buffer: BytesIO, array: pa.Array, data_type: pa.DataType, validity_buffer +): + """Serialize dictionary arrays using indices + dictionary.""" + # Write validity bitmap + _serialize_validity_buffer(buffer, validity_buffer) + + # Serialize indices array + indices_data = _serialize_array_binary(array.indices, data_type.index_type) + buffer.write(indices_data) + + # Serialize dictionary array + dict_data = _serialize_array_binary(array.dictionary, data_type.value_type) + buffer.write(dict_data) + + +def _serialize_validity_buffer(buffer: BytesIO, validity_buffer): + """Serialize validity (null) bitmap.""" + if validity_buffer is not None: + # Copy validity bitmap directly + buffer.write(validity_buffer.to_pybytes()) + # If no validity buffer, there are no nulls (implicit) + + +def _serialize_boolean_buffer(buffer: BytesIO, data_buffer, array_length: int): + """Serialize boolean buffer (bit-packed).""" + # Boolean data is bit-packed, copy directly + bool_bytes = data_buffer.to_pybytes() + buffer.write(struct.pack(" int: + """Get byte width of primitive types.""" + if pa.types.is_boolean(data_type): + return 1 # Bit-packed, but minimum 1 byte + elif pa.types.is_integer(data_type) or pa.types.is_floating(data_type): + return data_type.bit_width // 8 + elif pa.types.is_date(data_type): + return 4 if data_type == pa.date32() else 8 + elif pa.types.is_time(data_type) or pa.types.is_timestamp(data_type): + return data_type.bit_width // 8 + else: + return 8 # Default + + +def _serialize_array_fallback( + buffer: BytesIO, array: pa.Array, data_type: pa.DataType, validity_buffer +): + """Fallback to element-wise processing for complex types.""" + # Write validity bitmap + _serialize_validity_buffer(buffer, validity_buffer) + + # Process element by element (only for types that need it) + for i in range(len(array)): + if array.is_null(i): + buffer.write(b"\x00") + else: + buffer.write(b"\x01") + # For complex nested types, we might still need .as_py() + # But this should be rare with proper binary handling above + value = array[i].as_py() + _serialize_complex_value(buffer, value, data_type) + + +def _serialize_complex_value(buffer: BytesIO, value: Any, data_type: pa.DataType): + """Serialize complex values that can't be handled by direct buffer access.""" + # This handles edge cases like nested structs with mixed types + if pa.types.is_decimal(data_type): + decimal_str = str(value).encode("utf-8") + buffer.write(struct.pack(" str: + """Create deterministic hash using binary serialization.""" + serialized = serialize_table_logical(table) + + if algorithm == "sha256": + hasher = hashlib.sha256() + elif algorithm == "sha3_256": + hasher = hashlib.sha3_256() + elif algorithm == "blake2b": + hasher = hashlib.blake2b() + else: + raise ValueError(f"Unsupported hash algorithm: {algorithm}") + + hasher.update(serialized) + return hasher.hexdigest() + + +def serialize_table_logical_streaming(table: pa.Table) -> str: + """ + Memory-efficient streaming version that produces the same hash as serialize_table_logical_hash. + + This version processes data in streaming fashion but maintains the same logical structure + as the non-streaming version to ensure identical hashes and chunking independence. 
+ """ + hasher = hashlib.sha256() + + # Hash format version (same as non-streaming) + hasher.update(b"ARROW_BINARY_V1") + + # Hash schema (same as non-streaming) + schema_buffer = BytesIO() + _serialize_schema_deterministic(schema_buffer, table.schema) + hasher.update(schema_buffer.getvalue()) + + # Process each column using the same logic as non-streaming + for i in range(table.num_columns): + column = table.column(i) + field = table.schema.field(i) + + # Use the same column serialization logic for chunking independence + column_data = _serialize_column_binary(column, field) + + # Hash the column data + hasher.update(column_data) + + return hasher.hexdigest() + + +# Test utilities +def create_test_table_1(): + """Create a basic test table with various data types.""" + return pa.table( + { + "int32_col": pa.array([1, 2, None, 4, 5], type=pa.int32()), + "float64_col": pa.array([1.1, 2.2, 3.3, None, 5.5], type=pa.float64()), + "string_col": pa.array(["hello", "world", None, "arrow", "fast"]), + "bool_col": pa.array([True, False, None, True, False]), + "binary_col": pa.array([b"data1", b"data2", None, b"data4", b"data5"]), + } + ) + + +def create_test_table_reordered_columns(): + """Same data as test_table_1 but with different column order.""" + return pa.table( + { + "string_col": pa.array(["hello", "world", None, "arrow", "fast"]), + "bool_col": pa.array([True, False, None, True, False]), + "int32_col": pa.array([1, 2, None, 4, 5], type=pa.int32()), + "binary_col": pa.array([b"data1", b"data2", None, b"data4", b"data5"]), + "float64_col": pa.array([1.1, 2.2, 3.3, None, 5.5], type=pa.float64()), + } + ) + + +def create_test_table_reordered_rows(): + """Same data as test_table_1 but with different row order.""" + return pa.table( + { + "int32_col": pa.array([5, 4, None, 2, 1], type=pa.int32()), + "float64_col": pa.array([5.5, None, 3.3, 2.2, 1.1], type=pa.float64()), + "string_col": pa.array(["fast", "arrow", None, "world", "hello"]), + "bool_col": pa.array([False, True, None, False, True]), + "binary_col": pa.array([b"data5", b"data4", None, b"data2", b"data1"]), + } + ) + + +def create_test_table_different_types(): + """Same logical data but with different Arrow types where possible.""" + return pa.table( + { + "int32_col": pa.array( + [1, 2, None, 4, 5], type=pa.int64() + ), # int64 instead of int32 + "float64_col": pa.array( + [1.1, 2.2, 3.3, None, 5.5], type=pa.float32() + ), # float32 instead of float64 + "string_col": pa.array(["hello", "world", None, "arrow", "fast"]), + "bool_col": pa.array([True, False, None, True, False]), + "binary_col": pa.array([b"data1", b"data2", None, b"data4", b"data5"]), + } + ) + + +def create_test_table_different_chunking(): + """Same data as test_table_1 but with different chunking.""" + # Create arrays with explicit chunking + int_chunks = [ + pa.array([1, 2], type=pa.int32()), + pa.array([None, 4, 5], type=pa.int32()), + ] + float_chunks = [ + pa.array([1.1], type=pa.float64()), + pa.array([2.2, 3.3, None, 5.5], type=pa.float64()), + ] + string_chunks = [pa.array(["hello", "world"]), pa.array([None, "arrow", "fast"])] + bool_chunks = [pa.array([True, False, None]), pa.array([True, False])] + binary_chunks = [ + pa.array([b"data1"]), + pa.array([b"data2", None, b"data4", b"data5"]), + ] + + return pa.table( + { + "int32_col": pa.chunked_array(int_chunks), + "float64_col": pa.chunked_array(float_chunks), + "string_col": pa.chunked_array(string_chunks), + "bool_col": pa.chunked_array(bool_chunks), + "binary_col": pa.chunked_array(binary_chunks), + } + ) + 
+ +def create_test_table_empty(): + """Create an empty table with same schema.""" + return pa.table( + { + "int32_col": pa.array([], type=pa.int32()), + "float64_col": pa.array([], type=pa.float64()), + "string_col": pa.array([], type=pa.string()), + "bool_col": pa.array([], type=pa.bool_()), + "binary_col": pa.array([], type=pa.binary()), + } + ) + + +def create_test_table_all_nulls(): + """Create a table with all null values.""" + return pa.table( + { + "int32_col": pa.array([None, None, None], type=pa.int32()), + "float64_col": pa.array([None, None, None], type=pa.float64()), + "string_col": pa.array([None, None, None], type=pa.string()), + "bool_col": pa.array([None, None, None], type=pa.bool_()), + "binary_col": pa.array([None, None, None], type=pa.binary()), + } + ) + + +def create_test_table_no_nulls(): + """Create a table with no null values.""" + return pa.table( + { + "int32_col": pa.array([1, 2, 3, 4, 5], type=pa.int32()), + "float64_col": pa.array([1.1, 2.2, 3.3, 4.4, 5.5], type=pa.float64()), + "string_col": pa.array(["hello", "world", "arrow", "fast", "data"]), + "bool_col": pa.array([True, False, True, False, True]), + "binary_col": pa.array([b"data1", b"data2", b"data3", b"data4", b"data5"]), + } + ) + + +def create_test_table_complex_types(): + """Create a table with complex nested types.""" + return pa.table( + { + "list_col": pa.array( + [[1, 2], [3, 4, 5], None, [], [6]], type=pa.list_(pa.int32()) + ), + "struct_col": pa.array( + [ + {"a": 1, "b": "x"}, + {"a": 2, "b": "y"}, + None, + {"a": 3, "b": "z"}, + {"a": 4, "b": "w"}, + ], + type=pa.struct([("a", pa.int32()), ("b", pa.string())]), + ), + "dict_col": pa.array( + ["apple", "banana", "apple", None, "cherry"] + ).dictionary_encode(), + } + ) + + +def create_test_table_single_column(): + """Create a table with just one column.""" + return pa.table({"single_col": pa.array([1, 2, 3, 4, 5], type=pa.int32())}) + + +def create_test_table_single_row(): + """Create a table with just one row.""" + return pa.table( + { + "int32_col": pa.array([42], type=pa.int32()), + "string_col": pa.array(["single"]), + "bool_col": pa.array([True]), + } + ) + + +def run_comprehensive_tests(): + """Run comprehensive test suite for serialization.""" + import time + + print("=" * 60) + print("COMPREHENSIVE ARROW SERIALIZATION TEST SUITE") + print("=" * 60) + + # Test cases + test_cases = [ + ("Basic table", create_test_table_1), + ("Reordered columns", create_test_table_reordered_columns), + ("Reordered rows", create_test_table_reordered_rows), + ("Different types", create_test_table_different_types), + ("Different chunking", create_test_table_different_chunking), + ("Empty table", create_test_table_empty), + ("All nulls", create_test_table_all_nulls), + ("No nulls", create_test_table_no_nulls), + ("Complex types", create_test_table_complex_types), + ("Single column", create_test_table_single_column), + ("Single row", create_test_table_single_row), + ] + + # Generate hashes for all test cases + results = {} + + print("\n1. 
GENERATING HASHES FOR ALL TEST CASES") + print("-" * 50) + + for name, create_func in test_cases: + try: + table = create_func() + + # Generate all hash types + logical_hash = serialize_table_logical_hash(table) + streaming_hash = serialize_table_logical_streaming(table) + ipc_hash = hashlib.sha256(serialize_table_ipc(table)).hexdigest() + + results[name] = { + "table": table, + "logical": logical_hash, + "streaming": streaming_hash, + "ipc": ipc_hash, + "rows": table.num_rows, + "cols": table.num_columns, + } + + print( + f"{name:20} | Rows: {table.num_rows:5} | Cols: {table.num_columns:2} | " + f"Logical: {logical_hash[:12]}... | IPC: {ipc_hash[:12]}..." + ) + + except Exception as e: + print(f"{name:20} | ERROR: {str(e)}") + results[name] = {"error": str(e)} + + print("\n2. DETERMINISM TESTS") + print("-" * 50) + + base_table = create_test_table_1() + + # Test multiple runs of same table + logical_hashes = [serialize_table_logical_hash(base_table) for _ in range(5)] + streaming_hashes = [serialize_table_logical_streaming(base_table) for _ in range(5)] + ipc_hashes = [ + hashlib.sha256(serialize_table_ipc(base_table)).hexdigest() for _ in range(5) + ] + + print( + f"Logical deterministic: {len(set(logical_hashes)) == 1} ({len(set(logical_hashes))}/5 unique)" + ) + print( + f"Streaming deterministic: {len(set(streaming_hashes)) == 1} ({len(set(streaming_hashes))}/5 unique)" + ) + print( + f"IPC deterministic: {len(set(ipc_hashes)) == 1} ({len(set(ipc_hashes))}/5 unique)" + ) + print(f"Streaming == Logical: {streaming_hashes[0] == logical_hashes[0]}") + + print("\n3. EQUIVALENCE TESTS") + print("-" * 50) + + base_logical = results["Basic table"]["logical"] + base_ipc = results["Basic table"]["ipc"] + + equivalence_tests = [ + ( + "Same table vs reordered columns", + "Reordered columns", + False, + "Different column order should produce different hash", + ), + ( + "Same table vs reordered rows", + "Reordered rows", + False, + "Different row order should produce different hash", + ), + ( + "Same table vs different types", + "Different types", + False, + "Different data types should produce different hash", + ), + ( + "Same table vs different chunking", + "Different chunking", + True, + "Same data with different chunking should produce same hash", + ), + ( + "Same table vs no nulls", + "No nulls", + False, + "Different null patterns should produce different hash", + ), + ( + "Same table vs all nulls", + "All nulls", + False, + "Different data should produce different hash", + ), + ] + + for test_name, compare_case, should_match, explanation in equivalence_tests: + if compare_case in results and "logical" in results[compare_case]: + compare_logical = results[compare_case]["logical"] + compare_ipc = results[compare_case]["ipc"] + + logical_match = base_logical == compare_logical + ipc_match = base_ipc == compare_ipc + + logical_status = "✓" if logical_match == should_match else "✗" + ipc_status = "✓" if ipc_match == should_match else "✗" + + print(f"{logical_status} {test_name}") + print(f" Logical: {logical_match} (expected: {should_match})") + print(f" IPC: {ipc_match} (expected: {should_match})") + print(f" Reason: {explanation}") + print() + + print("4. 
CHUNKING INDEPENDENCE DETAILED TEST") + print("-" * 50) + + # Test various chunking strategies + original_table = create_test_table_1() + combined_table = original_table.combine_chunks() + different_chunking = create_test_table_different_chunking() + + orig_logical = serialize_table_logical_hash(original_table) + comb_logical = serialize_table_logical_hash(combined_table) + diff_logical = serialize_table_logical_hash(different_chunking) + + orig_ipc = hashlib.sha256(serialize_table_ipc(original_table)).hexdigest() + comb_ipc = hashlib.sha256(serialize_table_ipc(combined_table)).hexdigest() + diff_ipc = hashlib.sha256(serialize_table_ipc(different_chunking)).hexdigest() + + print(f"Original chunking: {orig_logical[:16]}...") + print(f"Combined chunks: {comb_logical[:16]}...") + print(f"Different chunking: {diff_logical[:16]}...") + print( + f"Logical chunking-independent: {orig_logical == comb_logical == diff_logical}" + ) + print() + print(f"Original IPC: {orig_ipc[:16]}...") + print(f"Combined IPC: {comb_ipc[:16]}...") + print(f"Different IPC: {diff_ipc[:16]}...") + print(f"IPC chunking-independent: {orig_ipc == comb_ipc == diff_ipc}") + + print("\n5. PERFORMANCE COMPARISON") + print("-" * 50) + + # Create larger table for performance testing + large_size = 10000 + large_table = pa.table( + { + "int_col": pa.array(list(range(large_size)), type=pa.int32()), + "float_col": pa.array( + [i * 1.5 for i in range(large_size)], type=pa.float64() + ), + "string_col": pa.array([f"item_{i}" for i in range(large_size)]), + "bool_col": pa.array([i % 2 == 0 for i in range(large_size)]), + } + ) + + # Time each method + methods = [ + ("Logical", lambda t: serialize_table_logical_hash(t)), + ("Streaming", lambda t: serialize_table_logical_streaming(t)), + ("IPC", lambda t: hashlib.sha256(serialize_table_ipc(t)).hexdigest()), + ] + hash_result = "" + for method_name, method_func in methods: + times = [] + for _ in range(3): # Run 3 times for average + start = time.time() + hash_result = method_func(large_table) + end = time.time() + times.append(end - start) + + avg_time = sum(times) / len(times) + throughput = (large_size * 4) / avg_time # 4 columns + + print( + f"{method_name:10} | {avg_time * 1000:6.1f}ms | {throughput:8.0f} values/sec | {hash_result[:12]}..." + ) + + print("\n6. EDGE CASES") + print("-" * 50) + + edge_cases = ["Empty table", "All nulls", "Single column", "Single row"] + for case in edge_cases: + if case in results and "error" not in results[case]: + r = results[case] + print( + f"{case:15} | {r['rows']:3}r x {r['cols']:2}c | " + f"L:{r['logical'][:8]}... | I:{r['ipc'][:8]}... | " + f"Match: {r['logical'] == r['streaming']}" + ) + + print("\n7. COMPLEX TYPES TEST") + print("-" * 50) + + if "Complex types" in results and "error" not in results["Complex types"]: + complex_result = results["Complex types"] + print(f"Complex types serialization successful:") + print(f" Logical hash: {complex_result['logical']}") + print( + f" Streaming ==: {complex_result['logical'] == complex_result['streaming']}" + ) + print(f" Rows/Cols: {complex_result['rows']}r x {complex_result['cols']}c") + else: + print( + "Complex types test failed - this is expected for some complex nested types" + ) + + print(f"\n{'=' * 60}") + print("TEST SUITE COMPLETE") + print(f"{'=' * 60}") + + return results + + +# Main execution +if __name__ == "__main__": + # Run the comprehensive test suite + test_results = run_comprehensive_tests() From 7bc98e1255b043f9da8a151bd7e3053e27a526b5 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Tue, 1 Jul 2025 19:40:31 +0000 Subject: [PATCH 37/57] refactor: utils renaming and relocation --- .../stores/{file_ops.py => file_utils.py} | 6 +- src/orcapod/stores/safe_dir_data_store.py | 4 +- src/orcapod/types/utils.py | 62 ------------------ src/orcapod/utils/name.py | 64 +++++++++++++++++++ 4 files changed, 69 insertions(+), 67 deletions(-) rename src/orcapod/stores/{file_ops.py => file_utils.py} (99%) delete mode 100644 src/orcapod/types/utils.py diff --git a/src/orcapod/stores/file_ops.py b/src/orcapod/stores/file_utils.py similarity index 99% rename from src/orcapod/stores/file_ops.py rename to src/orcapod/stores/file_utils.py index 4fa6202..34380e0 100644 --- a/src/orcapod/stores/file_ops.py +++ b/src/orcapod/stores/file_utils.py @@ -7,7 +7,7 @@ import os from pathlib import Path -from orcapod.types import PathLike, PathSet, Packet +from orcapod.types import PathLike, PathSet, PacketLike from collections.abc import Collection, Callable @@ -369,8 +369,8 @@ def patched_open(file, *args, **kwargs): def virtual_mount( - packet: Packet, -) -> tuple[Packet, dict[str, str], dict[str, str]]: + packet: PacketLike, +) -> tuple[PacketLike, dict[str, str], dict[str, str]]: """ Visit all pathset within the packet, and convert them to alternative path representation. By default, full path is mapped to the file name. If two or diff --git a/src/orcapod/stores/safe_dir_data_store.py b/src/orcapod/stores/safe_dir_data_store.py index 7e16f63..e02e9cc 100644 --- a/src/orcapod/stores/safe_dir_data_store.py +++ b/src/orcapod/stores/safe_dir_data_store.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Optional, Union -from .file_ops import atomic_copy, atomic_write +from .file_utils import atomic_copy, atomic_write logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class FileLockError(Exception): @contextmanager def file_lock( - lock_path: Union[str, Path], + lock_path: str | Path, shared: bool = False, timeout: float = 30.0, delay: float = 0.1, diff --git a/src/orcapod/types/utils.py b/src/orcapod/types/utils.py deleted file mode 100644 index 5393492..0000000 --- a/src/orcapod/types/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -# TODO: move these functions to util -def escape_with_postfix(field: str, postfix=None, separator="_") -> str: - """ - Escape the field string by doubling separators and optionally append a postfix. - This function takes a field string and escapes any occurrences of the separator - by doubling them, then optionally appends a postfix with a separator prefix. - - Args: - field (str): The input string containing to be escaped. - postfix (str, optional): An optional postfix to append to the escaped string. - If None, no postfix is added. Defaults to None. - separator (str, optional): The separator character to escape and use for - prefixing the postfix. Defaults to "_". - Returns: - str: The escaped string with optional postfix. Returns empty string if - fields is provided but postfix is None. 
- Examples: - >>> escape_with_postfix("field1_field2", "suffix") - 'field1__field2_suffix' - >>> escape_with_postfix("name_age_city", "backup", "_") - 'name__age__city_backup' - >>> escape_with_postfix("data-info", "temp", "-") - 'data--info-temp' - >>> escape_with_postfix("simple", None) - 'simple' - >>> escape_with_postfix("no_separators", "end") - 'no__separators_end' - """ - - return field.replace(separator, separator * 2) + (f"_{postfix}" if postfix else "") - - -def unescape_with_postfix(field: str, separator="_") -> tuple[str, str | None]: - """ - Unescape a string by converting double separators back to single separators and extract postfix metadata. - This function reverses the escaping process where single separators were doubled to avoid - conflicts with metadata delimiters. It splits the input on double separators, then extracts - any postfix metadata from the last part. - - Args: - field (str): The escaped string containing doubled separators and optional postfix metadata - separator (str, optional): The separator character used for escaping. Defaults to "_" - Returns: - tuple[str, str | None]: A tuple containing: - - The unescaped string with single separators restored - - The postfix metadata if present, None otherwise - Examples: - >>> unescape_with_postfix("field1__field2__field3") - ('field1_field2_field3', None) - >>> unescape_with_postfix("field1__field2_metadata") - ('field1_field2', 'metadata') - >>> unescape_with_postfix("simple") - ('simple', None) - >>> unescape_with_postfix("field1--field2", separator="-") - ('field1-field2', None) - >>> unescape_with_postfix("field1--field2-meta", separator="-") - ('field1-field2', 'meta') - """ - - parts = field.split(separator * 2) - parts[-1], *meta = parts[-1].split("_", 1) - return separator.join(parts), meta[0] if meta else None diff --git a/src/orcapod/utils/name.py b/src/orcapod/utils/name.py index ba2c4f0..2211ef6 100644 --- a/src/orcapod/utils/name.py +++ b/src/orcapod/utils/name.py @@ -5,6 +5,70 @@ import re +# TODO: move these functions to util +def escape_with_postfix(field: str, postfix=None, separator="_") -> str: + """ + Escape the field string by doubling separators and optionally append a postfix. + This function takes a field string and escapes any occurrences of the separator + by doubling them, then optionally appends a postfix with a separator prefix. + + Args: + field (str): The input string containing to be escaped. + postfix (str, optional): An optional postfix to append to the escaped string. + If None, no postfix is added. Defaults to None. + separator (str, optional): The separator character to escape and use for + prefixing the postfix. Defaults to "_". + Returns: + str: The escaped string with optional postfix. Returns empty string if + fields is provided but postfix is None. + Examples: + >>> escape_with_postfix("field1_field2", "suffix") + 'field1__field2_suffix' + >>> escape_with_postfix("name_age_city", "backup", "_") + 'name__age__city_backup' + >>> escape_with_postfix("data-info", "temp", "-") + 'data--info-temp' + >>> escape_with_postfix("simple", None) + 'simple' + >>> escape_with_postfix("no_separators", "end") + 'no__separators_end' + """ + + return field.replace(separator, separator * 2) + (f"_{postfix}" if postfix else "") + + +def unescape_with_postfix(field: str, separator="_") -> tuple[str, str | None]: + """ + Unescape a string by converting double separators back to single separators and extract postfix metadata. 
+ This function reverses the escaping process where single separators were doubled to avoid + conflicts with metadata delimiters. It splits the input on double separators, then extracts + any postfix metadata from the last part. + + Args: + field (str): The escaped string containing doubled separators and optional postfix metadata + separator (str, optional): The separator character used for escaping. Defaults to "_" + Returns: + tuple[str, str | None]: A tuple containing: + - The unescaped string with single separators restored + - The postfix metadata if present, None otherwise + Examples: + >>> unescape_with_postfix("field1__field2__field3") + ('field1_field2_field3', None) + >>> unescape_with_postfix("field1__field2_metadata") + ('field1_field2', 'metadata') + >>> unescape_with_postfix("simple") + ('simple', None) + >>> unescape_with_postfix("field1--field2", separator="-") + ('field1-field2', None) + >>> unescape_with_postfix("field1--field2-meta", separator="-") + ('field1-field2', 'meta') + """ + + parts = field.split(separator * 2) + parts[-1], *meta = parts[-1].split("_", 1) + return separator.join(parts), meta[0] if meta else None + + def find_noncolliding_name(name: str, lut: dict) -> str: """ Generate a unique name that does not collide with existing keys in a lookup table (lut). From 51f3da283c2f7cd34fd2f2bb2abd12c43fd80901 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 2 Jul 2025 00:54:55 +0000 Subject: [PATCH 38/57] fix: cleanup imports and fix issue in recursive structure processing --- src/orcapod/core/sources.py | 67 +++++++++++++- src/orcapod/core/streams.py | 97 +++++++++++++++++++++ src/orcapod/core/tracker.py | 73 +--------------- src/orcapod/hashing/arrow_utils.py | 4 +- src/orcapod/hashing/content_identifiable.py | 14 ++- src/orcapod/hashing/defaults.py | 7 +- src/orcapod/hashing/hash_utils.py | 2 + src/orcapod/stores/file_utils.py | 2 +- src/orcapod/types/core.py | 10 ++- src/orcapod/types/packets.py | 12 ++- src/orcapod/types/typespec_utils.py | 2 +- 11 files changed, 199 insertions(+), 91 deletions(-) diff --git a/src/orcapod/core/sources.py b/src/orcapod/core/sources.py index 3d79e7a..b1dca7d 100644 --- a/src/orcapod/core/sources.py +++ b/src/orcapod/core/sources.py @@ -3,10 +3,17 @@ from pathlib import Path from typing import Any, Literal +import polars as pl + from orcapod.core.base import Source from orcapod.hashing.legacy_core import hash_function -from orcapod.core.streams import SyncStream, SyncStreamFromGenerator -from orcapod.types import Packet, Tag +from orcapod.core.streams import ( + PolarsStream, + SyncStream, + SyncStreamFromGenerator, + StreamWrapper, +) +from orcapod.types import Packet, Tag, TypeSpec class GlobSource(Source): @@ -139,3 +146,59 @@ def claims_unique_tags( return True # Otherwise, delegate to the base class return super().claims_unique_tags(trigger_run=trigger_run) + + +class PolarsSource(Source): + def __init__( + self, + df: pl.DataFrame, + tag_keys: Collection[str], + packet_keys: Collection[str] | None = None, + ): + self.df = df + self.tag_keys = tag_keys + self.packet_keys = packet_keys + + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: + if len(streams) != 0: + raise ValueError( + "PolarsSource does not support forwarding streams. " + "It generates its own stream from the DataFrame." 
+ ) + return PolarsStream(self.df, self.tag_keys, self.packet_keys) + + +class StreamSource(Source): + def __init__(self, stream: SyncStream, **kwargs): + super().__init__(skip_tracking=True, **kwargs) + self.stream = stream + + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 0: + raise ValueError( + "StreamSource does not support forwarding streams. " + "It generates its own stream from the file system." + ) + return StreamWrapper(self.stream) + + def identity_structure(self, *streams) -> Any: + if len(streams) != 0: + raise ValueError( + "StreamSource does not support forwarding streams. " + "It generates its own stream from the file system." + ) + + return (self.__class__.__name__, self.stream) + + def types( + self, *streams: SyncStream, **kwargs + ) -> tuple[TypeSpec | None, TypeSpec | None]: + return self.stream.types() + + def keys( + self, *streams: SyncStream, **kwargs + ) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.stream.keys() + + def computed_label(self) -> str | None: + return self.stream.label diff --git a/src/orcapod/core/streams.py b/src/orcapod/core/streams.py index 243a1f4..170c80d 100644 --- a/src/orcapod/core/streams.py +++ b/src/orcapod/core/streams.py @@ -1,5 +1,7 @@ from collections.abc import Callable, Collection, Iterator +import polars as pl + from orcapod.core.base import SyncStream from orcapod.types import Packet, PacketLike, Tag, TypeSpec from copy import copy @@ -104,3 +106,98 @@ def keys( return super().keys(trigger_run=trigger_run) # If the keys are already set, return them return self.tag_keys.copy(), self.packet_keys.copy() + + +class PolarsStream(SyncStream): + def __init__( + self, + df: pl.DataFrame, + tag_keys: Collection[str], + packet_keys: Collection[str] | None = None, + ): + self.df = df + self.tag_keys = tuple(tag_keys) + self.packet_keys = tuple(packet_keys) if packet_keys is not None else None + + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + df = self.df + # if self.packet_keys is not None: + # df = df.select(self.tag_keys + self.packet_keys) + for row in df.iter_rows(named=True): + tag = {key: row[key] for key in self.tag_keys} + packet = { + key: val + for key, val in row.items() + if key not in self.tag_keys and not key.startswith("_source_info_") + } + # TODO: revisit and fix this rather hacky implementation + source_info = { + key.removeprefix("_source_info_"): val + for key, val in row.items() + if key.startswith("_source_info_") + } + yield tag, Packet(packet, source_info=source_info) + + +class EmptyStream(SyncStream): + def __init__( + self, + tag_keys: Collection[str] | None = None, + packet_keys: Collection[str] | None = None, + tag_typespec: TypeSpec | None = None, + packet_typespec: TypeSpec | None = None, + ): + if tag_keys is None and tag_typespec is not None: + tag_keys = tag_typespec.keys() + self.tag_keys = list(tag_keys) if tag_keys else [] + + if packet_keys is None and packet_typespec is not None: + packet_keys = packet_typespec.keys() + self.packet_keys = list(packet_keys) if packet_keys else [] + + self.tag_typespec = tag_typespec + self.packet_typespec = packet_typespec + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.tag_keys, self.packet_keys + + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + return self.tag_typespec, self.packet_typespec + + def __iter__(self) -> Iterator[tuple[Tag, 
Packet]]: + # Empty stream, no data to yield + return iter([]) + + +class StreamWrapper(SyncStream): + """ + A wrapper for a SyncStream that allows the stream to be labeled and + associated with an invocation without modifying the original stream. + """ + + def __init__(self, stream: SyncStream, **kwargs): + super().__init__(**kwargs) + self.stream = stream + + def keys( + self, *streams: SyncStream, **kwargs + ) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.stream.keys(*streams, **kwargs) + + def types( + self, *streams: SyncStream, **kwargs + ) -> tuple[TypeSpec | None, TypeSpec | None]: + return self.stream.types(*streams, **kwargs) + + def computed_label(self) -> str | None: + return self.stream.label + + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + """ + Iterate over the stream, yielding tuples of (tags, packets). + """ + yield from self.stream diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py index 8f07ae3..e0a2bd7 100644 --- a/src/orcapod/core/tracker.py +++ b/src/orcapod/core/tracker.py @@ -1,74 +1,5 @@ -from orcapod.core.base import Invocation, Kernel, Tracker, SyncStream, Source -from orcapod.types import Tag, Packet, TypeSpec -from collections.abc import Collection, Iterator -from typing import Any - - -class StreamWrapper(SyncStream): - """ - A wrapper for a SyncStream that allows it to be used as a Source. - This is useful for cases where you want to treat a stream as a source - without modifying the original stream. - """ - - def __init__(self, stream: SyncStream, **kwargs): - super().__init__(**kwargs) - self.stream = stream - - def keys( - self, *streams: SyncStream, **kwargs - ) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.stream.keys(*streams, **kwargs) - - def types( - self, *streams: SyncStream, **kwargs - ) -> tuple[TypeSpec | None, TypeSpec | None]: - return self.stream.types(*streams, **kwargs) - - def computed_label(self) -> str | None: - return self.stream.label - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - """ - Iterate over the stream, yielding tuples of (tags, packets). - """ - yield from self.stream - - -class StreamSource(Source): - def __init__(self, stream: SyncStream, **kwargs): - super().__init__(skip_tracking=True, **kwargs) - self.stream = stream - - def forward(self, *streams: SyncStream) -> SyncStream: - if len(streams) != 0: - raise ValueError( - "StreamSource does not support forwarding streams. " - "It generates its own stream from the file system." - ) - return StreamWrapper(self.stream) - - def identity_structure(self, *streams) -> Any: - if len(streams) != 0: - raise ValueError( - "StreamSource does not support forwarding streams. " - "It generates its own stream from the file system." 
- ) - - return (self.__class__.__name__, self.stream) - - def types( - self, *streams: SyncStream, **kwargs - ) -> tuple[TypeSpec | None, TypeSpec | None]: - return self.stream.types() - - def keys( - self, *streams: SyncStream, **kwargs - ) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.stream.keys() - - def computed_label(self) -> str | None: - return self.stream.label +from orcapod.core.base import Invocation, Kernel, Tracker +from orcapod.core.sources import StreamSource class GraphTracker(Tracker): diff --git a/src/orcapod/hashing/arrow_utils.py b/src/orcapod/hashing/arrow_utils.py index 168c53f..0d46cd7 100644 --- a/src/orcapod/hashing/arrow_utils.py +++ b/src/orcapod/hashing/arrow_utils.py @@ -1,7 +1,7 @@ import pyarrow as pa import json import hashlib -from typing import Dict, List, Any +from typing import Dict, Any from decimal import Decimal import base64 @@ -168,7 +168,7 @@ def _arrow_type_to_python_type(arrow_type: pa.DataType) -> str: return str(arrow_type).lower() -def _extract_semantic_metadata(field_metadata) -> Dict[str, str]: +def _extract_semantic_metadata(field_metadata) -> dict[str, str]: """ Extract only 'semantic_type' metadata from field metadata. diff --git a/src/orcapod/hashing/content_identifiable.py b/src/orcapod/hashing/content_identifiable.py index 1581e62..ce1b6c3 100644 --- a/src/orcapod/hashing/content_identifiable.py +++ b/src/orcapod/hashing/content_identifiable.py @@ -1,9 +1,19 @@ -from .types import ObjectHasher -from .defaults import get_default_object_hasher +from orcapod.hashing.types import ObjectHasher +from orcapod.hashing.defaults import get_default_object_hasher from typing import Any class ContentIdentifiableBase: + """ + Base class for content-identifiable objects. + This class provides a way to define objects that can be uniquely identified + based on their content rather than their identity in memory. Specifically, the identity of the + object is determined by the structure returned by the `identity_structure` method. + The hash of the object is computed based on the `identity_structure` using the provided `ObjectHasher`, + which defaults to the one returned by `get_default_object_hasher`. + Two content-identifiable objects are considered equal if their `identity_structure` returns the same value. + """ + def __init__( self, identity_structure_hasher: ObjectHasher | None = None, diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index a9aebcd..3bae548 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -3,16 +3,13 @@ from orcapod.hashing.types import ( LegacyCompositeFileHasher, ArrowHasher, - FileContentHasher, StringCacher, ) -from orcapod.hashing.file_hashers import BasicFileHasher, LegacyPathLikeHasherFactory +from orcapod.hashing.file_hashers import LegacyPathLikeHasherFactory from orcapod.hashing.string_cachers import InMemoryCacher from orcapod.hashing.object_hashers import ObjectHasher from orcapod.hashing.object_hashers import LegacyObjectHasher from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory -from orcapod.hashing.arrow_hashers import SemanticArrowHasher -from orcapod.hashing.semantic_type_hashers import PathHasher from orcapod.hashing.versioned_hashers import ( get_versioned_semantic_arrow_hasher, get_versioned_object_hasher, @@ -24,7 +21,7 @@ def get_default_arrow_hasher( ) -> ArrowHasher: """ Get the default Arrow hasher with semantic type support. 
- If `with_cache` is True, it uses an in-memory cacher for caching hash values. + If `cache_file_hash` is True, it uses an in-memory cacher for caching hash values. If a `StringCacher` is provided, it uses that for caching file hashes. """ arrow_hasher = get_versioned_semantic_arrow_hasher() if cache_file_hash: diff --git a/src/orcapod/hashing/hash_utils.py b/src/orcapod/hashing/hash_utils.py index 7fee36b..0dc0777 100644 --- a/src/orcapod/hashing/hash_utils.py +++ b/src/orcapod/hashing/hash_utils.py @@ -48,6 +48,8 @@ def process_structure( # Initialize the visited set if this is the top-level call if visited is None: visited = set() + else: + visited = visited.copy() # Copy to avoid modifying the original set # Check for circular references - use object's memory address # NOTE: While id() is not stable across sessions, we only use it within a session diff --git a/src/orcapod/stores/file_utils.py b/src/orcapod/stores/file_utils.py index 34380e0..712aada 100644 --- a/src/orcapod/stores/file_utils.py +++ b/src/orcapod/stores/file_utils.py @@ -384,7 +384,7 @@ def virtual_mount( new_packet = {} for key, value in packet.items(): - new_packet[key] = convert_pathset(value, forward_lut, reverse_lut) + new_packet[key] = convert_pathset(value, forward_lut, reverse_lut) # type: ignore return new_packet, forward_lut, reverse_lut diff --git a/src/orcapod/types/core.py b/src/orcapod/types/core.py index 12448f8..62c100d 100644 --- a/src/orcapod/types/core.py +++ b/src/orcapod/types/core.py @@ -1,6 +1,4 @@ -from typing import Protocol, Any, TypeAlias, TypeVar, Generic -import pyarrow as pa -from dataclasses import dataclass +from typing import Protocol, Any, TypeAlias import os from collections.abc import Collection, Mapping @@ -34,7 +32,11 @@ # Extended data values that can be stored in packets # Either the original PathSet or one of our supported simple data types DataValue: TypeAlias = ( - PathSet | SupportedNativePythonData | None | Collection["DataValue"] + PathSet + | SupportedNativePythonData + | None + | Collection["DataValue"] + | Mapping[str, "DataValue"] ) diff --git a/src/orcapod/types/packets.py b/src/orcapod/types/packets.py index 47df081..a5836b1 100644 --- a/src/orcapod/types/packets.py +++ b/src/orcapod/types/packets.py @@ -4,9 +4,17 @@ from orcapod.types.core import TypeSpec, Tag, TypeHandler from orcapod.types.semantic_type_registry import SemanticTypeRegistry from orcapod.types import schemas +from orcapod.types.typespec_utils import get_typespec_from_dict import pyarrow as pa -# # a packet is a mapping from string keys to data values +# A conveniece packet-like type that defines a value that can be +# converted to a packet. It's broader than Packet and a simple mapping +# from string keys to DataValue (e.g., int, float, str) can be regarded +# as PacketLike, allowing for more flexible interfaces. +# Anything that requires Packet-like data but without the strict features +# of a Packet should accept PacketLike. +# One should be careful when using PacketLike as a return type as it does not +# enforce the typespec or source_info, which are important for packet integrity. 
PacketLike: TypeAlias = Mapping[str, DataValue] @@ -21,8 +29,6 @@ def __init__( obj = {} super().__init__(obj) if typespec is None: - from orcapod.types.typespec_utils import get_typespec_from_dict - typespec = get_typespec_from_dict(self) self._typespec = typespec if source_info is None: diff --git a/src/orcapod/types/typespec_utils.py b/src/orcapod/types/typespec_utils.py index 4e48004..a0a3c58 100644 --- a/src/orcapod/types/typespec_utils.py +++ b/src/orcapod/types/typespec_utils.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Collection, Sequence, Mapping from typing import get_origin, get_args, Any -from .core import TypeSpec +from orcapod.types.core import TypeSpec import inspect import logging From 3d54067bb87f231431b5ae1f85e5d54ca5e3d612 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 2 Jul 2025 00:55:40 +0000 Subject: [PATCH 39/57] refactor: add more robust arrow serialization strategy and use @ for separator in hash --- src/orcapod/hashing/arrow_hashers.py | 36 ++++++++++++++++++---------- src/orcapod/hashing/types.py | 2 +- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index 3545911..a7b5a01 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -1,11 +1,15 @@ import hashlib from typing import Any import pyarrow as pa -import pyarrow.ipc as ipc -from io import BytesIO import polars as pl import json from orcapod.hashing.types import SemanticTypeHasher, StringCacher +from orcapod.hashing import arrow_serialization_old +from collections.abc import Callable + +SERIALIZATION_METHOD_LUT: dict[str, Callable[[pa.Table], bytes]] = { + "logical": arrow_serialization_old.serialize_table_logical, +} def serialize_pyarrow_table(table: pa.Table) -> str: @@ -51,6 +55,7 @@ def __init__( chunk_size: int = 8192, handle_missing: str = "error", semantic_type_hashers: dict[str, SemanticTypeHasher] | None = None, + serialization_method: str = "logical", ): """ Initialize SemanticArrowHasher. @@ -66,6 +71,13 @@ def __init__( semantic_type_hashers or {} ) self.hash_algorithm = hash_algorithm + if serialization_method not in SERIALIZATION_METHOD_LUT: + raise ValueError( + f"Invalid serialization method '{serialization_method}'. 
" + f"Supported methods: {list(SERIALIZATION_METHOD_LUT.keys())}" + ) + self.serialization_method = serialization_method + self._serialize_arrow_table = SERIALIZATION_METHOD_LUT[serialization_method] def set_cacher(self, semantic_type: str, cacher: StringCacher) -> None: """ @@ -167,16 +179,16 @@ def _sort_table_columns(self, table: pa.Table) -> pa.Table: sorted_schema = pa.schema(sorted_fields) return pa.table(sorted_columns, schema=sorted_schema) - def _serialize_table_ipc(self, table: pa.Table) -> bytes: - # TODO: fix and use logical table hashing instead - """Serialize table using Arrow IPC format for stable binary representation.""" - buffer = BytesIO() + # def _serialize_table_ipc(self, table: pa.Table) -> bytes: + # # TODO: fix and use logical table hashing instead + # """Serialize table using Arrow IPC format for stable binary representation.""" + # buffer = BytesIO() - # Use IPC stream format for deterministic serialization - with ipc.new_stream(buffer, table.schema) as writer: - writer.write_table(table) + # # Use IPC stream format for deterministic serialization + # with ipc.new_stream(buffer, table.schema) as writer: + # writer.write_table(table) - return buffer.getvalue() + # return buffer.getvalue() def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: """ @@ -200,7 +212,7 @@ def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: sorted_table = pl.DataFrame(sorted_table).to_arrow() # Step 3: Serialize using Arrow IPC format - serialized_bytes = self._serialize_table_ipc(sorted_table) + serialized_bytes = self._serialize_arrow_table(sorted_table) # Step 4: Compute final hash hasher = hashlib.new(self.hash_algorithm) @@ -208,7 +220,7 @@ def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: hash_str = hasher.hexdigest() if prefix_hasher_id: - hash_str = f"{self.get_hasher_id()}:{hash_str}" + hash_str = f"{self.get_hasher_id()}@{hash_str}" return hash_str diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py index fabf812..6306d94 100644 --- a/src/orcapod/hashing/types.py +++ b/src/orcapod/hashing/types.py @@ -63,7 +63,7 @@ def hash_to_hex( ) hex_str = hex_str[:char_count] if prefix_hasher_id: - hex_str = self.get_hasher_id() + ":" + hex_str + hex_str = self.get_hasher_id() + "@" + hex_str return hex_str def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: From 1ac2be69b5f7770ff7b1d89bacd872765b343694 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 2 Jul 2025 00:56:02 +0000 Subject: [PATCH 40/57] feat: logical serialization for arrow table --- src/orcapod/hashing/arrow_serialization.py | 1036 ++++++++++++++------ 1 file changed, 730 insertions(+), 306 deletions(-) diff --git a/src/orcapod/hashing/arrow_serialization.py b/src/orcapod/hashing/arrow_serialization.py index e4926cb..fa0500f 100644 --- a/src/orcapod/hashing/arrow_serialization.py +++ b/src/orcapod/hashing/arrow_serialization.py @@ -1,47 +1,396 @@ import pyarrow as pa +import pyarrow.compute as pc from io import BytesIO -import pyarrow.ipc as ipc import struct from typing import Any import hashlib -def serialize_table_ipc(table: pa.Table) -> bytes: - # TODO: fix and use logical table hashing instead - """Serialize table using Arrow IPC format for stable binary representation.""" - buffer = BytesIO() +def bool_sequence_to_byte(sequence: list[bool]) -> bytes: + """Convert a sequence of booleans to a byte array.""" + if len(sequence) > 8: + raise ValueError("Sequence length exceeds 8 bits, cannot fit in a byte.") + mask = 1 + flags = 0 + for value in sequence: + if value: + flags |= mask + mask <<= 1 + return struct.pack(" bytes: + """Serialize order options to bytes for inclusion in format.""" + flags = 0 + if self.ignore_column_order: + flags |= 1 + if self.ignore_row_order: + flags |= 2 + return struct.pack(" "OrderOptions": + """Deserialize order options from bytes.""" + flags = struct.unpack(" pa.Array: + """ + Convert any Arrow array to string representation for sorting purposes. + Handles all data types including complex ones. + """ + if pa.types.is_string(array.type) or pa.types.is_large_string(array.type): + # Already string + return array + + elif pa.types.is_binary(array.type) or pa.types.is_large_binary(array.type): + # Convert binary to base64 string representation for deterministic sorting + try: + # Use Arrow's base64 encoding if available + import base64 + + str_values = [] + # Get null mask + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) # Will be handled by fill_null later + else: + binary_val = array[i].as_py() + if binary_val is not None: + str_values.append(base64.b64encode(binary_val).decode("ascii")) + else: + str_values.append(None) + return pa.array(str_values, type=pa.string()) + except Exception: + # Fallback: convert to hex string + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + binary_val = array[i].as_py() + if binary_val is not None: + str_values.append(binary_val.hex()) + else: + str_values.append(None) + except Exception: + str_values.append(f"BINARY_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + binary_val = array[i].as_py() + if binary_val is not None: + str_values.append(binary_val.hex()) + else: + str_values.append(None) + except Exception: + str_values.append(f"BINARY_{i}") + return pa.array(str_values, type=pa.string()) + + elif _is_primitive_type(array.type): + # Convert primitive types to string + try: + return pc.cast(array, pa.string()) + except Exception: + # Manual conversion for types that don't cast well + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + value = array[i].as_py() + str_values.append(str(value)) + except 
Exception: + str_values.append(f"PRIMITIVE_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + value = array[i].as_py() + if value is not None: + str_values.append(str(value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"PRIMITIVE_{i}") + return pa.array(str_values, type=pa.string()) + + elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type): + # Convert list to string representation + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + value = array[i].as_py() + # Sort list elements for consistent representation + if value is not None: + sorted_value = sorted( + value, key=lambda x: (x is None, str(x)) + ) + str_values.append(str(sorted_value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"LIST_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + value = array[i].as_py() + if value is not None: + sorted_value = sorted(value, key=lambda x: (x is None, str(x))) + str_values.append(str(sorted_value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"LIST_{i}") + return pa.array(str_values, type=pa.string()) + + elif pa.types.is_struct(array.type): + # Convert struct to string representation + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + value = array[i].as_py() + if value is not None: + # Sort dict keys for consistent representation + if isinstance(value, dict): + sorted_items = sorted( + value.items(), key=lambda x: str(x[0]) + ) + str_values.append(str(dict(sorted_items))) + else: + str_values.append(str(value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"STRUCT_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + value = array[i].as_py() + if value is not None: + if isinstance(value, dict): + sorted_items = sorted( + value.items(), key=lambda x: str(x[0]) + ) + str_values.append(str(dict(sorted_items))) + else: + str_values.append(str(value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"STRUCT_{i}") + return pa.array(str_values, type=pa.string()) + + elif pa.types.is_dictionary(array.type): + # Convert dictionary to string representation using the decoded values + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + value = array[i].as_py() + str_values.append(str(value)) + except Exception: + str_values.append(f"DICT_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + value = array[i].as_py() + if value is not None: + str_values.append(str(value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"DICT_{i}") + return pa.array(str_values, type=pa.string()) + + else: + # Generic fallback for any other types + try: + return pc.cast(array, pa.string()) + except Exception: + # Manual conversion as last resort + str_values = [] + try: + null_mask = pc.is_null(array) # type: ignore + for i in range(len(array)): + if null_mask[i].as_py(): + str_values.append(None) + else: + try: + value = 
array[i].as_py() + str_values.append(str(value)) + except Exception: + str_values.append(f"UNKNOWN_{array.type}_{i}") + except Exception: + # If null checking fails, just convert all values + for i in range(len(array)): + try: + value = array[i].as_py() + if value is not None: + str_values.append(str(value)) + else: + str_values.append(None) + except Exception: + str_values.append(f"UNKNOWN_{array.type}_{i}") + return pa.array(str_values, type=pa.string()) + + +def _create_row_sort_key(table: pa.Table) -> pa.Array: + """ + Create a deterministic sort key for rows by combining all column values. + This ensures consistent row ordering regardless of input order. + """ + if table.num_rows == 0: + return pa.array([], type=pa.string()) + + # Convert each column to string representation for sorting + sort_components = [] + + for i in range(table.num_columns): + column = table.column(i) + field = table.schema.field(i) + + # Combine all chunks into a single array + if column.num_chunks > 1: + combined_array = pa.concat_arrays(column.chunks) + elif column.num_chunks == 1: + combined_array = column.chunk(0) + else: + combined_array = pa.array([], type=field.type) + # Convert to string representation for sorting + str_array = _convert_array_to_string_for_sorting(combined_array) -def serialize_table_logical(table: pa.Table) -> bytes: + # Handle nulls by replacing with a consistent null representation + str_array = pc.fill_null(str_array, "NULL") + sort_components.append(str_array) + + # Combine all columns into a single sort key + if len(sort_components) == 1: + return sort_components[0] + else: + # Concatenate all string representations with separators + separator = pa.scalar("||") + combined = sort_components[0] + for component in sort_components[1:]: + combined = pc.binary_join_element_wise(combined, separator, component) # type: ignore + return combined + + +def _sort_table_by_content(table: pa.Table) -> pa.Table: + """Sort table rows based on content for deterministic ordering.""" + if table.num_rows <= 1: + return table + + # Create sort key + sort_key = _create_row_sort_key(table) + + # Get sort indices + sort_indices = pc.sort_indices(sort_key) # type: ignore + + # Apply sort to table + return pc.take(table, sort_indices) + + +def _sort_table_columns_by_name(table: pa.Table) -> pa.Table: + """Sort table columns alphabetically by name for deterministic ordering.""" + if table.num_columns <= 1: + return table + + # Get column names and sort them + column_names = [field.name for field in table.schema] + sorted_names = sorted(column_names) + + # If already sorted, return as-is + if column_names == sorted_names: + return table + + # Reorder columns + return table.select(sorted_names) + + +def serialize_table_logical( + table: pa.Table, order_options: OrderOptions | None = None +) -> bytes: """ Serialize table using column-wise processing with direct binary data access. This implementation works directly with Arrow's underlying binary buffers without converting to Python objects, making it much faster and more memory efficient while maintaining high repeatability. 
+ + Args: + table: PyArrow table to serialize + order_options: Options for handling column and row order independence """ + if order_options is None: + order_options = OrderOptions() + buffer = BytesIO() # Write format version - buffer.write(b"ARROW_BINARY_V1") + buffer.write(b"ARROW_BINARY_V1") # Updated version to include order options + + # Write order options + buffer.write(order_options.to_bytes()) + + # Apply ordering transformations if requested + processed_table = table + + if order_options.ignore_column_order: + processed_table = _sort_table_columns_by_name(processed_table) + + if order_options.ignore_row_order: + processed_table = _sort_table_by_content(processed_table) # Serialize schema deterministically - _serialize_schema_deterministic(buffer, table.schema) + _serialize_schema_deterministic(buffer, processed_table.schema) # Process each column using direct binary access column_digests = [] - for i in range(table.num_columns): - column = table.column(i) - field = table.schema.field(i) + for i in range(processed_table.num_columns): + column = processed_table.column(i) + field = processed_table.schema.field(i) column_digest = _serialize_column_binary(column, field) column_digests.append(column_digest) @@ -132,7 +481,7 @@ def _serialize_column_binary(column: pa.ChunkedArray, field: pa.Field) -> bytes: # Combine all chunks into a single array for consistent processing if column.num_chunks > 1: - # Multiple chunks - combine them + # Multiple chunks - combine them using pa.concat_arrays combined_array = pa.concat_arrays(column.chunks) elif column.num_chunks == 1: # Single chunk - use directly @@ -158,26 +507,37 @@ def _serialize_array_binary(array: pa.Array, data_type: pa.DataType) -> bytes: validity_buffer = array.buffers()[0] # Process based on Arrow type, accessing buffers directly - if _is_primitive_type(data_type): - _serialize_primitive_array_binary(buffer, array, data_type, validity_buffer) + try: + if _is_primitive_type(data_type): + _serialize_primitive_array_binary(buffer, array, data_type, validity_buffer) - elif pa.types.is_string(data_type) or pa.types.is_large_string(data_type): - _serialize_string_array_binary(buffer, array, data_type, validity_buffer) + elif pa.types.is_string(data_type) or pa.types.is_large_string(data_type): + _serialize_string_array_binary(buffer, array, data_type, validity_buffer) - elif pa.types.is_binary(data_type) or pa.types.is_large_binary(data_type): - _serialize_binary_array_binary(buffer, array, data_type, validity_buffer) + elif pa.types.is_binary(data_type) or pa.types.is_large_binary(data_type): + _serialize_binary_array_binary(buffer, array, data_type, validity_buffer) - elif pa.types.is_list(data_type) or pa.types.is_large_list(data_type): - _serialize_list_array_binary(buffer, array, data_type, validity_buffer) + elif pa.types.is_list(data_type) or pa.types.is_large_list(data_type): + _serialize_list_array_binary(buffer, array, data_type, validity_buffer) - elif pa.types.is_struct(data_type): - _serialize_struct_array_binary(buffer, array, data_type, validity_buffer) + elif pa.types.is_struct(data_type): + _serialize_struct_array_binary(buffer, array, data_type, validity_buffer) - elif pa.types.is_dictionary(data_type): - _serialize_dictionary_array_binary(buffer, array, data_type, validity_buffer) + elif pa.types.is_dictionary(data_type): + _serialize_dictionary_array_binary( + buffer, array, data_type, validity_buffer + ) - else: - # Fallback to element-wise processing for complex types + else: + # Fallback to element-wise 
processing for complex types + _serialize_array_fallback(buffer, array, data_type, validity_buffer) + + except Exception as e: + # If binary serialization fails, fall back to element-wise processing + print( + f"Warning: Binary serialization failed for {data_type}, falling back to element-wise: {e}" + ) + buffer = BytesIO() # Reset buffer _serialize_array_fallback(buffer, array, data_type, validity_buffer) return buffer.getvalue() @@ -252,17 +612,31 @@ def _serialize_list_array_binary( if offset_buffer is not None: buffer.write(offset_buffer.to_pybytes()) - # Get child array - handle both .children and .values access patterns + # Get child array - handle different access patterns child_array = None + + # Method 1: Try .values (most common for ListArray) if hasattr(array, "values") and array.values is not None: child_array = array.values - elif hasattr(array, "children") and len(array.children) > 0: + + # Method 2: Try .children (some array types) + elif hasattr(array, "children") and array.children and len(array.children) > 0: child_array = array.children[0] + # Method 3: Try accessing via flatten() for some list types + elif hasattr(array, "flatten"): + try: + child_array = array.flatten() + except Exception: + pass + # Recursively serialize child array if child_array is not None: child_data = _serialize_array_binary(child_array, data_type.value_type) buffer.write(child_data) + else: + # If we can't access child arrays directly, fall back to element-wise processing + _serialize_array_fallback(buffer, array, data_type, validity_buffer) def _serialize_struct_array_binary( @@ -272,8 +646,27 @@ def _serialize_struct_array_binary( # Write validity bitmap _serialize_validity_buffer(buffer, validity_buffer) + # Get child arrays - handle different access patterns for StructArray + child_arrays = [] + if hasattr(array, "field"): + # StructArray uses .field(i) to access child arrays + for i in range(len(data_type)): + child_arrays.append(array.field(i)) + elif hasattr(array, "children") and array.children: + # Some array types use .children + child_arrays = array.children + else: + # Fallback: try to access fields by iterating + try: + for i in range(len(data_type)): + child_arrays.append(array.field(i)) + except (AttributeError, IndexError): + # If all else fails, use element-wise processing + _serialize_array_fallback(buffer, array, data_type, validity_buffer) + return + # Serialize each child field - for i, child_array in enumerate(array.children): + for i, child_array in enumerate(child_arrays): field_type = data_type[i].type child_data = _serialize_array_binary(child_array, field_type) buffer.write(child_data) @@ -303,30 +696,6 @@ def _serialize_validity_buffer(buffer: BytesIO, validity_buffer): # If no validity buffer, there are no nulls (implicit) -def _serialize_boolean_buffer(buffer: BytesIO, data_buffer, array_length: int): - """Serialize boolean buffer (bit-packed).""" - # Boolean data is bit-packed, copy directly - bool_bytes = data_buffer.to_pybytes() - buffer.write(struct.pack(" int: - """Get byte width of primitive types.""" - if pa.types.is_boolean(data_type): - return 1 # Bit-packed, but minimum 1 byte - elif pa.types.is_integer(data_type) or pa.types.is_floating(data_type): - return data_type.bit_width // 8 - elif pa.types.is_date(data_type): - return 4 if data_type == pa.date32() else 8 - elif pa.types.is_time(data_type) or pa.types.is_timestamp(data_type): - return data_type.bit_width // 8 - else: - return 8 # Default - - def _serialize_array_fallback( buffer: BytesIO, array: 
pa.Array, data_type: pa.DataType, validity_buffer ): @@ -334,35 +703,132 @@ def _serialize_array_fallback( # Write validity bitmap _serialize_validity_buffer(buffer, validity_buffer) - # Process element by element (only for types that need it) + # Process element by element for i in range(len(array)): - if array.is_null(i): + try: + null_mask = pc.is_null(array) # type: ignore + is_null = null_mask[i].as_py() + except: + # Fallback null check + try: + value = array[i].as_py() + is_null = value is None + except: + is_null = False + + if is_null: buffer.write(b"\x00") else: buffer.write(b"\x01") - # For complex nested types, we might still need .as_py() - # But this should be rare with proper binary handling above - value = array[i].as_py() - _serialize_complex_value(buffer, value, data_type) + + # For complex nested types, convert to Python and serialize + try: + value = array[i].as_py() + _serialize_complex_value(buffer, value, data_type) + except Exception as e: + # If .as_py() fails, try alternative approaches + try: + # For some array types, we can access scalar values differently + scalar = array[i] + if hasattr(scalar, "value"): + value = scalar.value + else: + value = str(scalar) # Convert to string as last resort + _serialize_complex_value(buffer, value, data_type) + except Exception: + # Absolute fallback - serialize type name and index + fallback_str = f"{data_type}[{i}]" + fallback_bytes = fallback_str.encode("utf-8") + buffer.write(struct.pack(" str: +def serialize_table_logical_hash( + table: pa.Table, + algorithm: str = "sha256", + order_options: OrderOptions | None = None, +) -> str: """Create deterministic hash using binary serialization.""" - serialized = serialize_table_logical(table) + serialized = serialize_table_logical(table, order_options) if algorithm == "sha256": hasher = hashlib.sha256() @@ -377,27 +843,44 @@ def serialize_table_logical_hash(table: pa.Table, algorithm: str = "sha256") -> return hasher.hexdigest() -def serialize_table_logical_streaming(table: pa.Table) -> str: +def serialize_table_logical_streaming( + table: pa.Table, order_options: OrderOptions | None = None +) -> str: """ Memory-efficient streaming version that produces the same hash as serialize_table_logical_hash. This version processes data in streaming fashion but maintains the same logical structure as the non-streaming version to ensure identical hashes and chunking independence. 
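+
+    For example (a sketch; any small table should satisfy this)::
+
+        import pyarrow as pa
+
+        t = pa.table({"a": [1, 2, 3]})
+        assert serialize_table_logical_streaming(t) == serialize_table_logical_hash(t)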
""" + if order_options is None: + order_options = OrderOptions() + hasher = hashlib.sha256() # Hash format version (same as non-streaming) hasher.update(b"ARROW_BINARY_V1") + # Hash order options + hasher.update(order_options.to_bytes()) + + # Apply ordering transformations if requested + processed_table = table + + if order_options.ignore_column_order: + processed_table = _sort_table_columns_by_name(processed_table) + + if order_options.ignore_row_order: + processed_table = _sort_table_by_content(processed_table) + # Hash schema (same as non-streaming) schema_buffer = BytesIO() - _serialize_schema_deterministic(schema_buffer, table.schema) + _serialize_schema_deterministic(schema_buffer, processed_table.schema) hasher.update(schema_buffer.getvalue()) # Process each column using the same logic as non-streaming - for i in range(table.num_columns): - column = table.column(i) - field = table.schema.field(i) + for i in range(processed_table.num_columns): + column = processed_table.column(i) + field = processed_table.schema.field(i) # Use the same column serialization logic for chunking independence column_data = _serialize_column_binary(column, field) @@ -408,7 +891,46 @@ def serialize_table_logical_streaming(table: pa.Table) -> str: return hasher.hexdigest() -# Test utilities +# IPC serialization for comparison (updated to include order options for fair comparison) +def serialize_table_ipc( + table: pa.Table, order_options: OrderOptions | None = None +) -> bytes: + """Serialize table using Arrow IPC format for comparison.""" + from io import BytesIO + import pyarrow.ipc as ipc + + if order_options is None: + order_options = OrderOptions() + + buffer = BytesIO() + + # Add format version for consistency with logical serialization + buffer.write(b"ARROW_IPC_V2") + + # Add order options + buffer.write(order_options.to_bytes()) + + # Apply ordering transformations if requested + processed_table = table + + if order_options.ignore_column_order: + processed_table = _sort_table_columns_by_name(processed_table) + + if order_options.ignore_row_order: + processed_table = _sort_table_by_content(processed_table) + + # Standard IPC serialization + ipc_buffer = BytesIO() + with ipc.new_stream(ipc_buffer, processed_table.schema) as writer: + writer.write_table(processed_table) + + # Append IPC data + buffer.write(ipc_buffer.getvalue()) + + return buffer.getvalue() + + +# Test utilities (updated to test order independence) def create_test_table_1(): """Create a basic test table with various data types.""" return pa.table( @@ -494,45 +1016,6 @@ def create_test_table_different_chunking(): ) -def create_test_table_empty(): - """Create an empty table with same schema.""" - return pa.table( - { - "int32_col": pa.array([], type=pa.int32()), - "float64_col": pa.array([], type=pa.float64()), - "string_col": pa.array([], type=pa.string()), - "bool_col": pa.array([], type=pa.bool_()), - "binary_col": pa.array([], type=pa.binary()), - } - ) - - -def create_test_table_all_nulls(): - """Create a table with all null values.""" - return pa.table( - { - "int32_col": pa.array([None, None, None], type=pa.int32()), - "float64_col": pa.array([None, None, None], type=pa.float64()), - "string_col": pa.array([None, None, None], type=pa.string()), - "bool_col": pa.array([None, None, None], type=pa.bool_()), - "binary_col": pa.array([None, None, None], type=pa.binary()), - } - ) - - -def create_test_table_no_nulls(): - """Create a table with no null values.""" - return pa.table( - { - "int32_col": pa.array([1, 2, 3, 4, 5], 
type=pa.int32()), - "float64_col": pa.array([1.1, 2.2, 3.3, 4.4, 5.5], type=pa.float64()), - "string_col": pa.array(["hello", "world", "arrow", "fast", "data"]), - "bool_col": pa.array([True, False, True, False, True]), - "binary_col": pa.array([b"data1", b"data2", b"data3", b"data4", b"data5"]), - } - ) - - def create_test_table_complex_types(): """Create a table with complex nested types.""" return pa.table( @@ -557,29 +1040,13 @@ def create_test_table_complex_types(): ) -def create_test_table_single_column(): - """Create a table with just one column.""" - return pa.table({"single_col": pa.array([1, 2, 3, 4, 5], type=pa.int32())}) - - -def create_test_table_single_row(): - """Create a table with just one row.""" - return pa.table( - { - "int32_col": pa.array([42], type=pa.int32()), - "string_col": pa.array(["single"]), - "bool_col": pa.array([True]), - } - ) - - def run_comprehensive_tests(): - """Run comprehensive test suite for serialization.""" + """Run comprehensive test suite for serialization with order independence.""" import time - print("=" * 60) - print("COMPREHENSIVE ARROW SERIALIZATION TEST SUITE") - print("=" * 60) + print("=" * 70) + print("COMPREHENSIVE ARROW SERIALIZATION TEST SUITE (WITH ORDER OPTIONS)") + print("=" * 70) # Test cases test_cases = [ @@ -588,159 +1055,151 @@ def run_comprehensive_tests(): ("Reordered rows", create_test_table_reordered_rows), ("Different types", create_test_table_different_types), ("Different chunking", create_test_table_different_chunking), - ("Empty table", create_test_table_empty), - ("All nulls", create_test_table_all_nulls), - ("No nulls", create_test_table_no_nulls), ("Complex types", create_test_table_complex_types), - ("Single column", create_test_table_single_column), - ("Single row", create_test_table_single_row), ] - # Generate hashes for all test cases - results = {} + # Order option combinations to test + order_configs = [ + ("Default (order-sensitive)", OrderOptions(False, False)), + ("Column-order independent", OrderOptions(True, False)), + ("Row-order independent", OrderOptions(False, True)), + ("Fully order-independent", OrderOptions(True, True)), + ] - print("\n1. GENERATING HASHES FOR ALL TEST CASES") + print("\n1. ORDER INDEPENDENCE TESTS") print("-" * 50) - for name, create_func in test_cases: - try: - table = create_func() - - # Generate all hash types - logical_hash = serialize_table_logical_hash(table) - streaming_hash = serialize_table_logical_streaming(table) - ipc_hash = hashlib.sha256(serialize_table_ipc(table)).hexdigest() - - results[name] = { - "table": table, - "logical": logical_hash, - "streaming": streaming_hash, - "ipc": ipc_hash, - "rows": table.num_rows, - "cols": table.num_columns, - } - - print( - f"{name:20} | Rows: {table.num_rows:5} | Cols: {table.num_columns:2} | " - f"Logical: {logical_hash[:12]}... | IPC: {ipc_hash[:12]}..." - ) + base_table = create_test_table_1() + reordered_cols = create_test_table_reordered_columns() + reordered_rows = create_test_table_reordered_rows() - except Exception as e: - print(f"{name:20} | ERROR: {str(e)}") - results[name] = {"error": str(e)} + for config_name, order_opts in order_configs: + print(f"\n{config_name}:") + print(f" Config: {order_opts}") - print("\n2. 
DETERMINISM TESTS") - print("-" * 50) + # Test with base table + base_hash = serialize_table_logical_hash(base_table, order_options=order_opts) + cols_hash = serialize_table_logical_hash( + reordered_cols, order_options=order_opts + ) + rows_hash = serialize_table_logical_hash( + reordered_rows, order_options=order_opts + ) - base_table = create_test_table_1() + # Test streaming consistency + base_stream = serialize_table_logical_streaming( + base_table, order_options=order_opts + ) - # Test multiple runs of same table - logical_hashes = [serialize_table_logical_hash(base_table) for _ in range(5)] - streaming_hashes = [serialize_table_logical_streaming(base_table) for _ in range(5)] - ipc_hashes = [ - hashlib.sha256(serialize_table_ipc(base_table)).hexdigest() for _ in range(5) - ] + print(f" Base table: {base_hash[:12]}...") + print(f" Reordered columns: {cols_hash[:12]}...") + print(f" Reordered rows: {rows_hash[:12]}...") + print(f" Streaming matches: {base_hash == base_stream}") - print( - f"Logical deterministic: {len(set(logical_hashes)) == 1} ({len(set(logical_hashes))}/5 unique)" - ) - print( - f"Streaming deterministic: {len(set(streaming_hashes)) == 1} ({len(set(streaming_hashes))}/5 unique)" - ) - print( - f"IPC deterministic: {len(set(ipc_hashes)) == 1} ({len(set(ipc_hashes))}/5 unique)" - ) - print(f"Streaming == Logical: {streaming_hashes[0] == logical_hashes[0]}") + # Check expected behavior + cols_should_match = order_opts.ignore_column_order + rows_should_match = order_opts.ignore_row_order + + cols_match = base_hash == cols_hash + rows_match = base_hash == rows_hash + + cols_status = "✓" if cols_match == cols_should_match else "✗" + rows_status = "✓" if rows_match == rows_should_match else "✗" + + print( + f" {cols_status} Column order independence: {cols_match} (expected: {cols_should_match})" + ) + print( + f" {rows_status} Row order independence: {rows_match} (expected: {rows_should_match})" + ) - print("\n3. EQUIVALENCE TESTS") + print("\n2. 
CHUNKING INDEPENDENCE WITH ORDER OPTIONS") print("-" * 50) - base_logical = results["Basic table"]["logical"] - base_ipc = results["Basic table"]["ipc"] - - equivalence_tests = [ - ( - "Same table vs reordered columns", - "Reordered columns", - False, - "Different column order should produce different hash", - ), - ( - "Same table vs reordered rows", - "Reordered rows", - False, - "Different row order should produce different hash", - ), - ( - "Same table vs different types", - "Different types", - False, - "Different data types should produce different hash", - ), - ( - "Same table vs different chunking", - "Different chunking", - True, - "Same data with different chunking should produce same hash", - ), - ( - "Same table vs no nulls", - "No nulls", - False, - "Different null patterns should produce different hash", - ), - ( - "Same table vs all nulls", - "All nulls", - False, - "Different data should produce different hash", - ), - ] + original = create_test_table_1() + combined = original.combine_chunks() + different_chunking = create_test_table_different_chunking() + + for config_name, order_opts in order_configs: + orig_hash = serialize_table_logical_hash(original, order_options=order_opts) + comb_hash = serialize_table_logical_hash(combined, order_options=order_opts) + diff_hash = serialize_table_logical_hash( + different_chunking, order_options=order_opts + ) + + chunking_independent = orig_hash == comb_hash == diff_hash + status = "✓" if chunking_independent else "✗" + + print( + f"{status} {config_name:25} | Chunking independent: {chunking_independent}" + ) + + print("\n3. FORMAT VERSION COMPATIBILITY") + print("-" * 50) - for test_name, compare_case, should_match, explanation in equivalence_tests: - if compare_case in results and "logical" in results[compare_case]: - compare_logical = results[compare_case]["logical"] - compare_ipc = results[compare_case]["ipc"] + # Test that different order options produce different hashes when they should + test_table = create_test_table_1() - logical_match = base_logical == compare_logical - ipc_match = base_ipc == compare_ipc + hashes = {} + for config_name, order_opts in order_configs: + hash_value = serialize_table_logical_hash(test_table, order_options=order_opts) + hashes[config_name] = hash_value + print(f"{config_name:25} | {hash_value[:16]}...") - logical_status = "✓" if logical_match == should_match else "✗" - ipc_status = "✓" if ipc_match == should_match else "✗" + # Verify that order-sensitive vs order-independent produce different hashes + default_hash = hashes["Default (order-sensitive)"] + col_indep_hash = hashes["Column-order independent"] + row_indep_hash = hashes["Row-order independent"] + full_indep_hash = hashes["Fully order-independent"] - print(f"{logical_status} {test_name}") - print(f" Logical: {logical_match} (expected: {should_match})") - print(f" IPC: {ipc_match} (expected: {should_match})") - print(f" Reason: {explanation}") - print() + print(f"\nHash uniqueness:") + print(f" Default != Col-independent: {default_hash != col_indep_hash}") + print(f" Default != Row-independent: {default_hash != row_indep_hash}") + print(f" Default != Fully independent: {default_hash != full_indep_hash}") - print("4. CHUNKING INDEPENDENCE DETAILED TEST") + print("\n4. 
CONTENT EQUIVALENCE TEST") print("-" * 50) - # Test various chunking strategies - original_table = create_test_table_1() - combined_table = original_table.combine_chunks() - different_chunking = create_test_table_different_chunking() + # Create tables with same content but different presentation + table_a = pa.table({"col1": pa.array([1, 2, 3]), "col2": pa.array(["a", "b", "c"])}) - orig_logical = serialize_table_logical_hash(original_table) - comb_logical = serialize_table_logical_hash(combined_table) - diff_logical = serialize_table_logical_hash(different_chunking) + table_b = pa.table( + { + "col2": pa.array(["a", "b", "c"]), # Different column order + "col1": pa.array([1, 2, 3]), + } + ) - orig_ipc = hashlib.sha256(serialize_table_ipc(original_table)).hexdigest() - comb_ipc = hashlib.sha256(serialize_table_ipc(combined_table)).hexdigest() - diff_ipc = hashlib.sha256(serialize_table_ipc(different_chunking)).hexdigest() + table_c = pa.table( + { + "col1": pa.array([3, 1, 2]), # Different row order + "col2": pa.array(["c", "a", "b"]), + } + ) - print(f"Original chunking: {orig_logical[:16]}...") - print(f"Combined chunks: {comb_logical[:16]}...") - print(f"Different chunking: {diff_logical[:16]}...") - print( - f"Logical chunking-independent: {orig_logical == comb_logical == diff_logical}" + table_d = pa.table( + { + "col2": pa.array(["c", "a", "b"]), # Both different + "col1": pa.array([3, 1, 2]), + } ) - print() - print(f"Original IPC: {orig_ipc[:16]}...") - print(f"Combined IPC: {comb_ipc[:16]}...") - print(f"Different IPC: {diff_ipc[:16]}...") - print(f"IPC chunking-independent: {orig_ipc == comb_ipc == diff_ipc}") + + full_indep_opts = OrderOptions(True, True) + + hash_a = serialize_table_logical_hash(table_a, order_options=full_indep_opts) + hash_b = serialize_table_logical_hash(table_b, order_options=full_indep_opts) + hash_c = serialize_table_logical_hash(table_c, order_options=full_indep_opts) + hash_d = serialize_table_logical_hash(table_d, order_options=full_indep_opts) + + all_match = hash_a == hash_b == hash_c == hash_d + status = "✓" if all_match else "✗" + + print(f"{status} Content equivalence test:") + print(f" Table A (original): {hash_a[:12]}...") + print(f" Table B (reord cols): {hash_b[:12]}...") + print(f" Table C (reord rows): {hash_c[:12]}...") + print(f" Table D (both reord): {hash_d[:12]}...") + print(f" All hashes match: {all_match}") print("\n5. PERFORMANCE COMPARISON") print("-" * 50) @@ -758,18 +1217,14 @@ def run_comprehensive_tests(): } ) - # Time each method - methods = [ - ("Logical", lambda t: serialize_table_logical_hash(t)), - ("Streaming", lambda t: serialize_table_logical_streaming(t)), - ("IPC", lambda t: hashlib.sha256(serialize_table_ipc(t)).hexdigest()), - ] - hash_result = "" - for method_name, method_func in methods: + # Time each method with different order options + for config_name, order_opts in order_configs: times = [] for _ in range(3): # Run 3 times for average start = time.time() - hash_result = method_func(large_table) + hash_result = serialize_table_logical_hash( + large_table, order_options=order_opts + ) end = time.time() times.append(end - start) @@ -777,43 +1232,12 @@ def run_comprehensive_tests(): throughput = (large_size * 4) / avg_time # 4 columns print( - f"{method_name:10} | {avg_time * 1000:6.1f}ms | {throughput:8.0f} values/sec | {hash_result[:12]}..." + f"{config_name:25} | {avg_time * 1000:6.1f}ms | {throughput:8.0f} values/sec" ) - print("\n6. 
EDGE CASES") - print("-" * 50) - - edge_cases = ["Empty table", "All nulls", "Single column", "Single row"] - for case in edge_cases: - if case in results and "error" not in results[case]: - r = results[case] - print( - f"{case:15} | {r['rows']:3}r x {r['cols']:2}c | " - f"L:{r['logical'][:8]}... | I:{r['ipc'][:8]}... | " - f"Match: {r['logical'] == r['streaming']}" - ) - - print("\n7. COMPLEX TYPES TEST") - print("-" * 50) - - if "Complex types" in results and "error" not in results["Complex types"]: - complex_result = results["Complex types"] - print(f"Complex types serialization successful:") - print(f" Logical hash: {complex_result['logical']}") - print( - f" Streaming ==: {complex_result['logical'] == complex_result['streaming']}" - ) - print(f" Rows/Cols: {complex_result['rows']}r x {complex_result['cols']}c") - else: - print( - "Complex types test failed - this is expected for some complex nested types" - ) - - print(f"\n{'=' * 60}") - print("TEST SUITE COMPLETE") - print(f"{'=' * 60}") - - return results + print(f"\n{'=' * 70}") + print("ORDER-INDEPENDENT SERIALIZATION TEST SUITE COMPLETE") + print(f"{'=' * 70}") # Main execution From dab33787cdd0a88b382a89d462e8ae5d92d71dcb Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 2 Jul 2025 00:56:51 +0000 Subject: [PATCH 41/57] feat: update versioned arrow hasher to use new serialization --- src/orcapod/hashing/object_hashers.py | 21 ++- src/orcapod/hashing/versioned_hashers.py | 5 +- src/orcapod/pipeline/nodes.py | 213 ++++++++++------------- 3 files changed, 110 insertions(+), 129 deletions(-) diff --git a/src/orcapod/hashing/object_hashers.py b/src/orcapod/hashing/object_hashers.py index bdd0169..3401574 100644 --- a/src/orcapod/hashing/object_hashers.py +++ b/src/orcapod/hashing/object_hashers.py @@ -1,7 +1,6 @@ -from polars import Object -from .types import FunctionInfoExtractor, ObjectHasher -from .legacy_core import legacy_hash -from .hash_utils import hash_object +from orcapod.hashing.types import FunctionInfoExtractor, ObjectHasher +from orcapod.hashing import legacy_core +from orcapod.hashing import hash_utils class BasicObjectHasher(ObjectHasher): @@ -30,7 +29,9 @@ def hash(self, obj: object) -> bytes: Returns: bytes: The byte representation of the hash. """ - return hash_object(obj, function_info_extractor=self.function_info_extractor) + return hash_utils.hash_object( + obj, function_info_extractor=self.function_info_extractor + ) class LegacyObjectHasher(ObjectHasher): @@ -54,6 +55,12 @@ def __init__( """ self.function_info_extractor = function_info_extractor + def get_hasher_id(self) -> str: + """ + Returns a unique identifier/name assigned to the hasher + """ + return "legacy_object_hasher" + def hash(self, obj: object) -> bytes: """ Hash an object to a byte representation. @@ -64,4 +71,6 @@ def hash(self, obj: object) -> bytes: Returns: bytes: The byte representation of the hash. """ - return legacy_hash(obj, function_info_extractor=self.function_info_extractor) + return legacy_core.legacy_hash( + obj, function_info_extractor=self.function_info_extractor + ) diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py index e6095a0..d2fec4d 100644 --- a/src/orcapod/hashing/versioned_hashers.py +++ b/src/orcapod/hashing/versioned_hashers.py @@ -1,6 +1,6 @@ # A collection of versioned hashers that provide a "default" implementation of hashers. 
from .arrow_hashers import SemanticArrowHasher -from .types import ObjectHasher, ArrowHasher +from .types import ObjectHasher import importlib from typing import Any @@ -13,6 +13,7 @@ "hasher_id": "arrow_v0.1", "hash_algorithm": "sha256", "chunk_size": 8192, + "serialization_method": "logical", "semantic_type_hashers": { "path": { "_class": "orcapod.hashing.semantic_type_hashers.PathHasher", @@ -65,7 +66,7 @@ def parse_objectspec(obj_spec: dict) -> Any: def get_versioned_semantic_arrow_hasher( version: str | None = None, -) -> ArrowHasher: +) -> SemanticArrowHasher: """ Get the versioned hasher for the specified version. diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index b5bd54e..fabb664 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -1,19 +1,16 @@ from orcapod.core.pod import Pod, FunctionPod from orcapod.core import SyncStream, Source, Kernel +from orcapod.core.streams import PolarsStream +from orcapod.core.streams import EmptyStream from orcapod.stores import ArrowDataStore from orcapod.types import Tag, Packet, PacketLike, TypeSpec, default_registry -from orcapod.types.typespec_utils import ( - get_typespec_from_dict, - union_typespecs, - extract_function_typespecs, -) +from orcapod.types.typespec_utils import union_typespecs from orcapod.types.semantic_type_registry import SemanticTypeRegistry from orcapod.types import packets, schemas from orcapod.hashing import ObjectHasher, ArrowHasher from orcapod.hashing.defaults import get_default_object_hasher, get_default_arrow_hasher from typing import Any, Literal from collections.abc import Collection, Iterator -import pyarrow as pa import polars as pl from orcapod.core.streams import SyncStreamFromGenerator @@ -26,91 +23,6 @@ def get_tag_typespec(tag: Tag) -> dict[str, type]: return {k: str for k in tag} -class PolarsSource(Source): - def __init__( - self, - df: pl.DataFrame, - tag_keys: Collection[str], - packet_keys: Collection[str] | None = None, - ): - self.df = df - self.tag_keys = tag_keys - self.packet_keys = packet_keys - - def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: - if len(streams) != 0: - raise ValueError( - "PolarsSource does not support forwarding streams. " - "It generates its own stream from the DataFrame." 
- ) - return PolarsStream(self.df, self.tag_keys, self.packet_keys) - - -class PolarsStream(SyncStream): - def __init__( - self, - df: pl.DataFrame, - tag_keys: Collection[str], - packet_keys: Collection[str] | None = None, - ): - self.df = df - self.tag_keys = tuple(tag_keys) - self.packet_keys = tuple(packet_keys) if packet_keys is not None else None - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - df = self.df - # if self.packet_keys is not None: - # df = df.select(self.tag_keys + self.packet_keys) - for row in df.iter_rows(named=True): - tag = {key: row[key] for key in self.tag_keys} - packet = { - key: val - for key, val in row.items() - if key not in self.tag_keys and not key.startswith("_source_info_") - } - # TODO: revisit and fix this rather hacky implementation - source_info = { - key.removeprefix("_source_info_"): val - for key, val in row.items() - if key.startswith("_source_info_") - } - yield tag, Packet(packet, source_info=source_info) - - -class EmptyStream(SyncStream): - def __init__( - self, - tag_keys: Collection[str] | None = None, - packet_keys: Collection[str] | None = None, - tag_typespec: TypeSpec | None = None, - packet_typespec: TypeSpec | None = None, - ): - if tag_keys is None and tag_typespec is not None: - tag_keys = tag_typespec.keys() - self.tag_keys = list(tag_keys) if tag_keys else [] - - if packet_keys is None and packet_typespec is not None: - packet_keys = packet_typespec.keys() - self.packet_keys = list(packet_keys) if packet_keys else [] - - self.tag_typespec = tag_typespec - self.packet_typespec = packet_typespec - - def keys( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.tag_keys, self.packet_keys - - def types( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[TypeSpec | None, TypeSpec | None]: - return self.tag_typespec, self.packet_typespec - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - # Empty stream, no data to yield - return iter([]) - - class KernelInvocationWrapper(Kernel): def __init__( self, kernel: Kernel, input_streams: Collection[SyncStream], **kwargs @@ -119,10 +31,10 @@ def __init__( self.kernel = kernel self.input_streams = list(input_streams) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.kernel!r}>" - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}<{self.kernel}>" def computed_label(self) -> str | None: @@ -187,7 +99,7 @@ def __init__( kernel: Kernel, input_streams: Collection[SyncStream], output_store: ArrowDataStore, - store_path_prefix: tuple[str, ...] | None = None, + store_path_prefix: tuple[str, ...] 
= (), kernel_hasher: ObjectHasher | None = None, arrow_packet_hasher: ArrowHasher | None = None, packet_type_registry: SemanticTypeRegistry | None = None, @@ -196,7 +108,7 @@ def __init__( super().__init__(kernel, input_streams, **kwargs) self.output_store = output_store - self.store_path_prefix = store_path_prefix or () + self.store_path_prefix = store_path_prefix # These are configurable but are not expected to be modified except for special circumstances if kernel_hasher is None: @@ -235,10 +147,27 @@ def kernel_hasher(self, kernel_hasher: ObjectHasher | None = None): # hasher changed -- trigger recomputation of properties that depend on kernel hasher self.update_cached_values() + @property + def source_info(self) -> tuple[str, ...]: + """ + Returns a tuple of (label, kernel_hash) that uniquely identifies the source of the cached outputs. + This is used to store and retrieve the outputs from the output store. + """ + return self.label, self.kernel_hasher.hash_to_hex( + self.kernel, prefix_hasher_id=True + ) + + @property + def store_path(self) -> tuple[str, ...]: + """ + Returns the path prefix for the output store. + This is used to store and retrieve the outputs from the output store. + """ + return self.store_path_prefix + self.source_info + def update_cached_values(self): - self.source_info = self.store_path_prefix + ( - self.label, - self.kernel_hasher.hash_to_hex(self.kernel, prefix_hasher_id=True), + self.kernel_hash = self.kernel_hasher.hash_to_hex( + self.kernel, prefix_hasher_id=True ) self.tag_keys, self.packet_keys = self.keys(trigger_run=False) self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) @@ -269,7 +198,6 @@ def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: raise ValueError( "CachedKernelWrapper has no tag keys defined, cannot return PolarsStream" ) - source_info_sig = ":".join(self.source_info) return PolarsStream( self.df, tag_keys=self.tag_keys, packet_keys=self.packet_keys ) @@ -304,9 +232,9 @@ def post_call(self, tag: Tag, packet: Packet) -> None: ) # TODO: revisit this logic output_id = self.arrow_hasher.hash_table(output_table, prefix_hasher_id=True) - if not self.output_store.get_record(self.source_info, output_id): + if not self.output_store.get_record(self.store_path, output_id): self.output_store.add_record( - self.source_info, + self.store_path, output_id, output_table, ) @@ -320,7 +248,7 @@ def output_iterator_completion_hook(self) -> None: @property def lazy_df(self) -> pl.LazyFrame | None: - return self.output_store.get_all_records_as_polars(self.source_info) + return self.output_store.get_all_records_as_polars(self.store_path) @property def df(self) -> pl.DataFrame | None: @@ -378,7 +306,9 @@ def __init__( output_store: ArrowDataStore, tag_store: ArrowDataStore | None = None, label: str | None = None, - store_path_prefix: tuple[str, ...] | None = None, + store_path_prefix: tuple[str, ...] = (), + output_store_path_prefix: tuple[str, ...] = (), + tag_store_path_prefix: tuple[str, ...] 
= (), skip_memoization_lookup: bool = False, skip_memoization: bool = False, skip_tag_record: bool = False, @@ -395,7 +325,9 @@ def __init__( error_handling=error_handling, **kwargs, ) - self.store_path_prefix = store_path_prefix or () + self.output_store_path_prefix = store_path_prefix + output_store_path_prefix + self.tag_store_path_prefix = store_path_prefix + tag_store_path_prefix + self.output_store = output_store self.tag_store = tag_store @@ -419,6 +351,18 @@ def __init__( self.update_cached_values() self._cache_computed = False + @property + def tag_keys(self) -> tuple[str, ...]: + if self._tag_keys is None: + raise ValueError("Tag keys are not set, cannot return tag keys") + return self._tag_keys + + @property + def output_keys(self) -> tuple[str, ...]: + if self._output_keys is None: + raise ValueError("Output keys are not set, cannot return output keys") + return self._output_keys + @property def object_hasher(self) -> ObjectHasher: return self._object_hasher @@ -459,17 +403,19 @@ def update_cached_values(self) -> None: self.function_pod_hash = self.object_hasher.hash_to_hex( self.function_pod, prefix_hasher_id=True ) + self.node_hash = self.object_hasher.hash_to_hex(self, prefix_hasher_id=True) self.input_typespec, self.output_typespec = ( self.function_pod.get_function_typespecs() ) - self.tag_keys, self.output_keys = self.keys(trigger_run=False) + tag_keys, output_keys = self.keys(trigger_run=False) - if self.tag_keys is None or self.output_keys is None: + if tag_keys is None or output_keys is None: raise ValueError( "Currently, cached function pod wrapper can only work with function pods that have keys defined." ) - self.tag_keys = tuple(self.tag_keys) - self.output_keys = tuple(self.output_keys) + self._tag_keys = tuple(tag_keys) + self._output_keys = tuple(output_keys) + self.tag_typespec, self.output_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.output_typespec is None: raise ValueError( @@ -532,9 +478,29 @@ def get_packet_key(self, packet: Packet) -> str: ) @property - def source_info(self): + def pod_source_info(self): return self.function_pod.function_name, self.function_pod_hash + @property + def node_source_info(self): + return self.label, self.node_hash + + @property + def output_store_path(self) -> tuple[str, ...]: + """ + Returns the path prefix for the output store. + This is used to store and retrieve the outputs from the output store. + """ + return self.output_store_path_prefix + self.pod_source_info + + @property + def tag_store_path(self) -> tuple[str, ...]: + """ + Returns the path prefix for the tag store. + This is used to store and retrieve the tags associated with memoized packets. 
+ """ + return self.tag_store_path_prefix + self.node_source_info + def is_memoized(self, packet: Packet) -> bool: return self.retrieve_memoized(packet) is not None @@ -566,9 +532,9 @@ def _add_pipeline_record_with_packet_key( # TODO: add error handling # check if record already exists: - retrieved_table = self.tag_store.get_record(self.source_info, entry_hash) + retrieved_table = self.tag_store.get_record(self.tag_store_path, entry_hash) if retrieved_table is None: - self.tag_store.add_record(self.source_info, entry_hash, table) + self.tag_store.add_record(self.tag_store_path, entry_hash, table) return tag @@ -587,7 +553,7 @@ def _retrieve_memoized_with_packet_key(self, packet_key: str) -> Packet | None: """ logger.debug(f"Retrieving memoized packet with key {packet_key}") arrow_table = self.output_store.get_record( - self.source_info, + self.output_store_path, packet_key, ) if arrow_table is None: @@ -625,7 +591,7 @@ def _memoize_with_packet_key( # consider simpler alternative packets = self.output_converter.from_arrow_table_to_python_packets( self.output_store.add_record( - self.source_info, + self.output_store_path, packet_key, self.output_converter.from_python_packet_to_arrow_table(output_packet), ) @@ -668,7 +634,7 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: # e.g. if the output is a file, the path may be changed # add source info to the output packet source_info = { - k: "-".join(self.source_info) + "-" + packet_key + k: "-".join(self.pod_source_info) + "-" + packet_key + ":" + str(k) for k in output_packet.source_info } # TODO: fix and make this not access protected field directly @@ -691,12 +657,12 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet | None]: return tag, output_packet def get_all_outputs(self) -> pl.LazyFrame | None: - return self.output_store.get_all_records_as_polars(self.source_info) + return self.output_store.get_all_records_as_polars(self.output_store_path) def get_all_tags(self, with_packet_id: bool = False) -> pl.LazyFrame | None: if self.tag_store is None: raise ValueError("Tag store is not set, no tag record can be retrieved") - data = self.tag_store.get_all_records_as_polars(self.source_info) + data = self.tag_store.get_all_records_as_polars(self.tag_store_path) if not with_packet_id: return data.drop("__packet_key") if data is not None else None return data @@ -711,11 +677,11 @@ def get_all_entries_with_tags( if self.tag_store is None: raise ValueError("Tag store is not set, no tag record can be retrieved") - tag_records = self.tag_store.get_all_records_as_polars(self.source_info) + tag_records = self.tag_store.get_all_records_as_polars(self.tag_store_path) if tag_records is None: return None result_packets = self.output_store.get_records_by_ids_as_polars( - self.source_info, + self.output_store_path, tag_records.collect()["__packet_key"], preserve_input_order=True, ) @@ -790,14 +756,19 @@ class DummyCachedFunctionPod(CachedFunctionPodWrapper): """ def __init__(self, source_pod: CachedFunctionPodWrapper): - self._source_info = source_pod.source_info + self._pod_source_info = source_pod.pod_source_info + self._node_source_info = source_pod.node_source_info self.output_store = source_pod.output_store self.tag_store = source_pod.tag_store self.function_pod = DummyFunctionPod(source_pod.function_pod.function_name) @property - def source_info(self) -> tuple[str, str]: - return self._source_info + def pod_source_info(self) -> tuple[str, str]: + return self._pod_source_info + + @property + def node_source_info(self) -> 
tuple[str, str]: + return self._node_source_info class Node(KernelInvocationWrapper, Source): From 4f079270a136656eca3f99bc6577d1d35b28e156 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 2 Jul 2025 00:57:10 +0000 Subject: [PATCH 42/57] wip: delta table store implementation --- .../stores/delta_table_arrow_data_store.py | 683 +++++++++++++++--- 1 file changed, 598 insertions(+), 85 deletions(-) diff --git a/src/orcapod/stores/delta_table_arrow_data_store.py b/src/orcapod/stores/delta_table_arrow_data_store.py index c05dea9..e5ddfb9 100644 --- a/src/orcapod/stores/delta_table_arrow_data_store.py +++ b/src/orcapod/stores/delta_table_arrow_data_store.py @@ -1,10 +1,14 @@ import pyarrow as pa +import pyarrow.compute as pc import polars as pl from pathlib import Path -from typing import Any, Union +from typing import Any, Dict, List import logging from deltalake import DeltaTable, write_deltalake from deltalake.exceptions import TableNotFoundError +import threading +from collections import defaultdict +import json # Module-level logger logger = logging.getLogger(__name__) @@ -12,7 +16,7 @@ class DeltaTableArrowDataStore: """ - Delta Table-based Arrow data store with flexible hierarchical path support. + Delta Table-based Arrow data store with flexible hierarchical path support and schema preservation. Uses tuple-based source paths for robust parameter handling: - ("source_name", "source_id") -> source_name/source_id/ @@ -26,6 +30,8 @@ def __init__( duplicate_entry_behavior: str = "error", create_base_path: bool = True, max_hierarchy_depth: int = 10, + batch_size: int = 100, + auto_flush_interval: float = 300.0, # 5 minutes ): """ Initialize the DeltaTableArrowDataStore. @@ -37,6 +43,8 @@ def __init__( - 'overwrite': Replace existing entry with new data create_base_path: Whether to create the base path if it doesn't exist max_hierarchy_depth: Maximum allowed depth for source paths (safety limit) + batch_size: Number of records to batch before writing to Delta table + auto_flush_interval: Time in seconds to auto-flush pending batches (0 to disable) """ # Validate duplicate behavior if duplicate_entry_behavior not in ["error", "overwrite"]: @@ -45,6 +53,8 @@ def __init__( self.duplicate_entry_behavior = duplicate_entry_behavior self.base_path = Path(base_path) self.max_hierarchy_depth = max_hierarchy_depth + self.batch_size = batch_size + self.auto_flush_interval = auto_flush_interval if create_base_path: self.base_path.mkdir(parents=True, exist_ok=True) @@ -56,11 +66,55 @@ def __init__( # Cache for Delta tables to avoid repeated initialization self._delta_table_cache: dict[str, DeltaTable] = {} + # Cache for original schemas (without __entry_id column) + self._schema_cache: dict[str, pa.Schema] = {} + + # Batch management + self._pending_batches: Dict[str, List[pa.Table]] = defaultdict(list) + self._batch_lock = threading.Lock() + + # Auto-flush timer + self._flush_timer = None + # if auto_flush_interval > 0: + # self._start_auto_flush_timer() + logger.info( f"Initialized DeltaTableArrowDataStore at {self.base_path} " - f"with duplicate_entry_behavior='{duplicate_entry_behavior}'" + f"with duplicate_entry_behavior='{duplicate_entry_behavior}', " + f"batch_size={batch_size}, auto_flush_interval={auto_flush_interval}s" ) + def _start_auto_flush_timer(self): + """Start the auto-flush timer.""" + if self._flush_timer: + self._flush_timer.cancel() + + if self.auto_flush_interval > 0: + self._flush_timer = threading.Timer( + self.auto_flush_interval, self._auto_flush + ) + 
self._flush_timer.daemon = True + self._flush_timer.start() + + def _auto_flush(self): + """Auto-flush all pending batches.""" + try: + print("Flushing!", flush=True) + self.flush_all_batches() + except Exception as e: + logger.error(f"Error during auto-flush: {e}") + finally: + self._start_auto_flush_timer() + + def __del__(self): + """Cleanup when object is destroyed.""" + try: + if self._flush_timer: + self._flush_timer.cancel() + self.flush_all_batches() + except Exception: + pass # Ignore errors during cleanup + def _validate_source_path(self, source_path: tuple[str, ...]) -> None: """ Validate source path components. @@ -104,6 +158,204 @@ def _get_table_path(self, source_path: tuple[str, ...]) -> Path: path = path / component return path + def _get_schema_metadata_path(self, source_path: tuple[str, ...]) -> Path: + """Get the path for storing original schema metadata.""" + table_path = self._get_table_path(source_path) + return table_path / "_original_schema.json" + + def _save_original_schema( + self, source_path: tuple[str, ...], schema: pa.Schema + ) -> None: + """Save the original schema (without __entry_id) to metadata file.""" + source_key = self._get_source_key(source_path) + + # Cache the schema + self._schema_cache[source_key] = schema + + try: + # Save to file as well for persistence + schema_path = self._get_schema_metadata_path(source_path) + schema_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert schema to JSON-serializable format + def convert_metadata(metadata): + """Convert Arrow metadata (bytes keys/values) to JSON-safe format.""" + if metadata is None: + return None + result = {} + for key, value in metadata.items(): + # Convert bytes keys and values to strings + str_key = ( + key.decode("utf-8") if isinstance(key, bytes) else str(key) + ) + str_value = ( + value.decode("utf-8") + if isinstance(value, bytes) + else str(value) + ) + result[str_key] = str_value + return result + + schema_dict = { + "fields": [ + { + "name": field.name, + "type": str(field.type), + "nullable": field.nullable, + "metadata": convert_metadata(field.metadata), + } + for field in schema + ], + "metadata": convert_metadata(schema.metadata), + } + + with open(schema_path, "w") as f: + json.dump(schema_dict, f, indent=2) + + except Exception as e: + logger.warning(f"Could not save schema metadata for {source_key}: {e}") + + def _load_original_schema(self, source_path: tuple[str, ...]) -> pa.Schema | None: + """Load the original schema from cache or metadata file.""" + source_key = self._get_source_key(source_path) + + # Check cache first + if source_key in self._schema_cache: + return self._schema_cache[source_key] + + # Try to load from file + try: + schema_path = self._get_schema_metadata_path(source_path) + if not schema_path.exists(): + return None + + with open(schema_path, "r") as f: + schema_dict = json.load(f) + + # Reconstruct schema from JSON + def convert_metadata_back(metadata_dict): + """Convert JSON metadata back to Arrow format (bytes keys/values).""" + if metadata_dict is None: + return None + result = {} + for key, value in metadata_dict.items(): + # Convert string keys and values back to bytes + bytes_key = key.encode("utf-8") + bytes_value = ( + value.encode("utf-8") + if isinstance(value, str) + else str(value).encode("utf-8") + ) + result[bytes_key] = bytes_value + return result + + fields = [] + for field_dict in schema_dict["fields"]: + # Parse the type string back to Arrow type + type_str = field_dict["type"] + arrow_type = 
self._parse_arrow_type_string(type_str) + + metadata = convert_metadata_back(field_dict.get("metadata")) + + field = pa.field( + field_dict["name"], + arrow_type, + nullable=field_dict["nullable"], + metadata=metadata, + ) + fields.append(field) + + schema_metadata = convert_metadata_back(schema_dict.get("metadata")) + + schema = pa.schema(fields, metadata=schema_metadata) + + # Cache it + self._schema_cache[source_key] = schema + return schema + + except Exception as e: + logger.warning(f"Could not load schema metadata for {source_key}: {e}") + return None + + def _parse_arrow_type_string(self, type_str: str) -> pa.DataType: + """Parse Arrow type string back to Arrow type object.""" + # This is a simplified parser for common types + # You might need to extend this for more complex types + type_str = type_str.strip() + + # Handle basic types + if type_str == "int64": + return pa.int64() + elif type_str == "int32": + return pa.int32() + elif type_str == "float64": + return pa.float64() + elif type_str == "float32": + return pa.float32() + elif type_str == "bool": + return pa.bool_() + elif type_str == "string": + return pa.string() + elif type_str == "large_string": + return pa.large_string() + elif type_str == "binary": + return pa.binary() + elif type_str == "large_binary": + return pa.large_binary() + elif type_str.startswith("timestamp"): + # Extract timezone if present + if "[" in type_str and "]" in type_str: + tz = type_str.split("[")[1].split("]")[0] + if tz == "UTC": + tz = "UTC" + return pa.timestamp("us", tz=tz) + else: + return pa.timestamp("us") + elif type_str.startswith("list<"): + # Parse list type + inner_type_str = type_str[5:-1] # Remove 'list<' and '>' + inner_type = self._parse_arrow_type_string(inner_type_str) + return pa.list_(inner_type) + else: + # Fallback to string for unknown types + logger.warning(f"Unknown Arrow type string: {type_str}, using string") + return pa.string() + + def _get_or_create_delta_table( + self, source_path: tuple[str, ...] + ) -> DeltaTable | None: + """ + Get or create a Delta table, handling schema initialization properly. + + Args: + source_path: Tuple of path components + + Returns: + DeltaTable instance or None if table doesn't exist + """ + source_key = self._get_source_key(source_path) + table_path = self._get_table_path(source_path) + + # Check cache first + if source_key in self._delta_table_cache: + return self._delta_table_cache[source_key] + + try: + # Try to load existing table + delta_table = DeltaTable(str(table_path)) + self._delta_table_cache[source_key] = delta_table + logger.debug(f"Loaded existing Delta table for {source_key}") + return delta_table + except TableNotFoundError: + # Table doesn't exist + return None + except Exception as e: + logger.error(f"Error loading Delta table for {source_key}: {e}") + # Try to clear any corrupted cache and retry once + if source_key in self._delta_table_cache: + del self._delta_table_cache[source_key] + return None + def _ensure_entry_id_column(self, arrow_data: pa.Table, entry_id: str) -> pa.Table: """Ensure the table has an __entry_id column.""" if "__entry_id" not in arrow_data.column_names: @@ -150,21 +402,199 @@ def _handle_entry_id_column( # If add_entry_id_column is True, keep __entry_id as is return arrow_data + def _create_entry_id_filter(self, entry_id: str) -> list: + """ + Create a proper filter expression for Delta Lake. 
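+
+        For example, ``_create_entry_id_filter("abc")`` returns
+        ``[("__entry_id", "=", "abc")]``, the tuple-style filter that
+        ``DeltaTable.to_pyarrow_table(filters=...)`` accepts elsewhere in
+        this class.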
+ + Args: + entry_id: The entry ID to filter by + + Returns: + List containing the filter expression for Delta Lake + """ + return [("__entry_id", "=", entry_id)] + + def _create_entry_ids_filter(self, entry_ids: list[str]) -> list: + """ + Create a proper filter expression for multiple entry IDs. + + Args: + entry_ids: List of entry IDs to filter by + + Returns: + List containing the filter expression for Delta Lake + """ + return [("__entry_id", "in", entry_ids)] + + def _read_table_with_schema_preservation( + self, + delta_table: DeltaTable, + source_path: tuple[str, ...], + filters: list = None, + ) -> pa.Table: + """ + Read table using to_pyarrow_dataset with original schema preservation. + + Args: + delta_table: The Delta table to read from + source_path: Source path for schema lookup + filters: Optional filters to apply + + Returns: + Arrow table with preserved schema + """ + try: + # Get the original schema (without __entry_id) + original_schema = self._load_original_schema(source_path) + + if original_schema is not None: + # Create target schema with __entry_id column + entry_id_field = pa.field( + "__entry_id", pa.large_string(), nullable=False + ) + target_schema = pa.schema([entry_id_field] + list(original_schema)) + + # Use to_pyarrow_dataset with the target schema + dataset = delta_table.to_pyarrow_dataset(schema=target_schema) + if filters: + # Apply filters at dataset level for better performance + import pyarrow.compute as pc + + filter_expr = None + for filt in filters: + if len(filt) == 3: + col, op, val = filt + if op == "=": + expr = pc.equal(pc.field(col), pa.scalar(val)) + elif op == "in": + expr = pc.is_in(pc.field(col), pa.array(val)) + else: + # Fallback to table-level filtering + return delta_table.to_pyarrow_table(filters=filters) + + if filter_expr is None: + filter_expr = expr + else: + filter_expr = pc.and_(filter_expr, expr) + + if filter_expr is not None: + return dataset.to_table(filter=filter_expr) + + return dataset.to_table() + else: + # Fallback to regular method if no schema found + logger.warning( + f"No original schema found for {'/'.join(source_path)}, using fallback" + ) + return delta_table.to_pyarrow_table(filters=filters) + + except Exception as e: + logger.warning( + f"Error reading with schema preservation: {e}, falling back to regular method" + ) + return delta_table.to_pyarrow_table(filters=filters) + + def _flush_batch(self, source_path: tuple[str, ...]) -> None: + """ + Flush pending batch for a specific source path. 
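+
+        This is normally triggered from ``add_record`` once ``batch_size``
+        pending records accumulate; the public ``flush_batch()`` and
+        ``flush_all_batches()`` methods expose the same path for manual use.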
+ + Args: + source_path: Tuple of path components + """ + print("Flushing triggered!!", flush=True) + source_key = self._get_source_key(source_path) + + with self._batch_lock: + if ( + source_key not in self._pending_batches + or not self._pending_batches[source_key] + ): + return + + # Get all pending records + pending_tables = self._pending_batches[source_key] + self._pending_batches[source_key] = [] + + if not pending_tables: + return + + try: + # Combine all tables in the batch + combined_table = pa.concat_tables(pending_tables) + + table_path = self._get_table_path(source_path) + table_path.mkdir(parents=True, exist_ok=True) + + # Check if table exists + delta_table = self._get_or_create_delta_table(source_path) + + if delta_table is None: + # Create new table - save original schema first + original_schema = self._remove_entry_id_column(combined_table).schema + self._save_original_schema(source_path, original_schema) + + write_deltalake(str(table_path), combined_table, mode="overwrite") + logger.debug( + f"Created new Delta table for {source_key} with {len(combined_table)} records" + ) + else: + # Handle duplicates if needed + if self.duplicate_entry_behavior == "overwrite": + # Get entry IDs from the batch + entry_ids = combined_table.column("__entry_id").to_pylist() + unique_entry_ids = list(set(entry_ids)) + + # Delete existing records with these IDs + if unique_entry_ids: + entry_ids_str = "', '".join(unique_entry_ids) + delete_predicate = f"__entry_id IN ('{entry_ids_str}')" + try: + delta_table.delete(delete_predicate) + logger.debug( + f"Deleted {len(unique_entry_ids)} existing records from {source_key}" + ) + except Exception as e: + logger.debug( + f"No existing records to delete from {source_key}: {e}" + ) + + # Append new records + write_deltalake( + str(table_path), combined_table, mode="append", schema_mode="merge" + ) + logger.debug( + f"Appended batch of {len(combined_table)} records to {source_key}" + ) + + # Update cache + self._delta_table_cache[source_key] = DeltaTable(str(table_path)) + + except Exception as e: + logger.error(f"Error flushing batch for {source_key}: {e}") + # Put the tables back in the pending queue + with self._batch_lock: + self._pending_batches[source_key] = ( + pending_tables + self._pending_batches[source_key] + ) + raise + def add_record( self, source_path: tuple[str, ...], entry_id: str, arrow_data: pa.Table, ignore_duplicate: bool = False, + force_flush: bool = False, ) -> pa.Table: """ - Add a record to the Delta table. + Add a record to the Delta table (batched). 
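+
+        A minimal usage sketch (``store`` is a ``DeltaTableArrowDataStore``
+        instance; the path components, entry id and table are illustrative)::
+
+            import pyarrow as pa
+
+            tbl = pa.table({"value": [1, 2, 3]})
+            store.add_record(("org", "project", "dataset"), "entry-001", tbl)
+            store.flush_batch(("org", "project", "dataset"))  # or rely on batch_size / auto-flush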
Args: source_path: Tuple of path components (e.g., ("org", "project", "dataset")) entry_id: Unique identifier for this record arrow_data: The Arrow table data to store ignore_duplicate: If True, ignore duplicate entry error + force_flush: If True, immediately flush this record to disk Returns: The Arrow table data that was stored @@ -173,18 +603,15 @@ def add_record( ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' """ self._validate_source_path(source_path) - - table_path = self._get_table_path(source_path) source_key = self._get_source_key(source_path) - # Ensure directory exists - table_path.mkdir(parents=True, exist_ok=True) - - # Add entry_id column to the data - data_with_entry_id = self._ensure_entry_id_column(arrow_data, entry_id) - - # Check for existing entry if needed - if not ignore_duplicate and self.duplicate_entry_behavior == "error": + # Check for existing entry if needed (only for immediate duplicates, not batch) + if ( + not ignore_duplicate + and self.duplicate_entry_behavior == "error" + and not force_flush + ): + # Only check existing table, not pending batch for performance existing_record = self.get_record(source_path, entry_id) if existing_record is not None: raise ValueError( @@ -192,60 +619,121 @@ def add_record( f"Use duplicate_entry_behavior='overwrite' to allow updates." ) - try: - # Try to load existing table - delta_table = DeltaTable(str(table_path)) + # Save original schema if this is the first record for this source + if source_key not in self._schema_cache: + self._save_original_schema(source_path, arrow_data.schema) - if self.duplicate_entry_behavior == "overwrite": - # Delete existing record if it exists, then append new one - try: - # First, delete existing record with this entry_id - delta_table.delete(f"__entry_id = '{entry_id}'") - logger.debug( - f"Deleted existing record {entry_id} from {source_key}" - ) - except Exception as e: - # If delete fails (e.g., record doesn't exist), that's fine - logger.debug(f"No existing record to delete for {entry_id}: {e}") + # Add entry_id column to the data + data_with_entry_id = self._ensure_entry_id_column(arrow_data, entry_id) - # Append new record - write_deltalake( - str(table_path), data_with_entry_id, mode="append", schema_mode="merge" - ) + if force_flush: + # Write immediately + table_path = self._get_table_path(source_path) + table_path.mkdir(parents=True, exist_ok=True) + + delta_table = self._get_or_create_delta_table(source_path) + + if delta_table is None: + # Create new table - save original schema first + self._save_original_schema(source_path, arrow_data.schema) + write_deltalake(str(table_path), data_with_entry_id, mode="overwrite") + logger.debug(f"Created new Delta table for {source_key}") + else: + if self.duplicate_entry_behavior == "overwrite": + try: + delta_table.delete( + f"__entry_id = '{entry_id.replace(chr(39), chr(39) + chr(39))}'" + ) + logger.debug( + f"Deleted existing record {entry_id} from {source_key}" + ) + except Exception as e: + logger.debug( + f"No existing record to delete for {entry_id}: {e}" + ) + + write_deltalake( + str(table_path), + data_with_entry_id, + mode="append", + schema_mode="merge", + ) - except TableNotFoundError: - # Table doesn't exist, create it - write_deltalake(str(table_path), data_with_entry_id, mode="overwrite") - logger.debug(f"Created new Delta table for {source_key}") + # Update cache + self._delta_table_cache[source_key] = DeltaTable(str(table_path)) + else: + # Add to batch + with self._batch_lock: + 
self._pending_batches[source_key].append(data_with_entry_id) + batch_size = len(self._pending_batches[source_key]) - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) + # Check if we need to flush + if batch_size >= self.batch_size: + self._flush_batch(source_path) logger.debug(f"Added record {entry_id} to {source_key}") return arrow_data + def flush_batch(self, source_path: tuple[str, ...]) -> None: + """ + Manually flush pending batch for a specific source path. + + Args: + source_path: Tuple of path components + """ + self._flush_batch(source_path) + + def flush_all_batches(self) -> None: + """Flush all pending batches.""" + with self._batch_lock: + source_keys = list(self._pending_batches.keys()) + + for source_key in source_keys: + source_path = tuple(source_key.split("/")) + try: + self._flush_batch(source_path) + except Exception as e: + logger.error(f"Error flushing batch for {source_key}: {e}") + + def get_pending_batch_info(self) -> Dict[str, int]: + """ + Get information about pending batches. + + Returns: + Dictionary mapping source keys to number of pending records + """ + with self._batch_lock: + return { + source_key: len(tables) + for source_key, tables in self._pending_batches.items() + if tables + } + def get_record( self, source_path: tuple[str, ...], entry_id: str ) -> pa.Table | None: """ - Get a specific record by entry_id. + Get a specific record by entry_id with schema preservation. Args: source_path: Tuple of path components entry_id: Unique identifier for the record Returns: - Arrow table for the record, or None if not found + Arrow table for the record with original schema, or None if not found """ self._validate_source_path(source_path) - table_path = self._get_table_path(source_path) + delta_table = self._get_or_create_delta_table(source_path) + if delta_table is None: + return None try: - delta_table = DeltaTable(str(table_path)) - - # Query for the specific entry_id - result = delta_table.to_pyarrow_table(filter=f"__entry_id = '{entry_id}'") + # Use schema-preserving read + filter_expr = self._create_entry_id_filter(entry_id) + result = self._read_table_with_schema_preservation( + delta_table, source_path, filters=filter_expr + ) if len(result) == 0: return None @@ -253,19 +741,17 @@ def get_record( # Remove the __entry_id column before returning return self._remove_entry_id_column(result) - except TableNotFoundError: - return None except Exception as e: logger.error( f"Error getting record {entry_id} from {'/'.join(source_path)}: {e}" ) - return None + raise e def get_all_records( self, source_path: tuple[str, ...], add_entry_id_column: bool | str = False ) -> pa.Table | None: """ - Retrieve all records for a given source path as a single table. + Retrieve all records for a given source path as a single table with schema preservation. 
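        Example (an illustrative sketch; it assumes `store` is an instance of this
        Delta-backed store that has had records queued via add_record):

            # Number of records still waiting to be written, keyed by source
            print(store.get_pending_batch_info())

            # Write out every pending batch, then read everything back
            store.flush_all_batches()
            table = store.get_all_records(("org", "project", "dataset"), add_entry_id_column=True)
            if table is not None:
                print(table.num_rows)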
Args: source_path: Tuple of path components @@ -275,15 +761,17 @@ def get_all_records( - str: Include entry ID column with custom name Returns: - Arrow table containing all records, or None if no records found + Arrow table containing all records with original schema, or None if no records found """ self._validate_source_path(source_path) - table_path = self._get_table_path(source_path) + delta_table = self._get_or_create_delta_table(source_path) + if delta_table is None: + return None try: - delta_table = DeltaTable(str(table_path)) - result = delta_table.to_pyarrow_table() + # Use schema-preserving read + result = self._read_table_with_schema_preservation(delta_table, source_path) if len(result) == 0: return None @@ -291,8 +779,6 @@ def get_all_records( # Handle entry_id column based on parameter return self._handle_entry_id_column(result, add_entry_id_column) - except TableNotFoundError: - return None except Exception as e: logger.error(f"Error getting all records from {'/'.join(source_path)}: {e}") return None @@ -322,7 +808,7 @@ def get_records_by_ids( preserve_input_order: bool = False, ) -> pa.Table | None: """ - Retrieve records by entry IDs as a single table. + Retrieve records by entry IDs as a single table with schema preservation. Args: source_path: Tuple of path components @@ -331,7 +817,7 @@ def get_records_by_ids( preserve_input_order: If True, return results in input order with nulls for missing Returns: - Arrow table containing all found records, or None if no records found + Arrow table containing all found records with original schema, or None if no records found """ self._validate_source_path(source_path) @@ -353,16 +839,16 @@ def get_records_by_ids( f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" ) - table_path = self._get_table_path(source_path) + delta_table = self._get_or_create_delta_table(source_path) + if delta_table is None: + return None try: - delta_table = DeltaTable(str(table_path)) - - # Create filter for the entry IDs - escape single quotes in IDs - escaped_ids = [id_.replace("'", "''") for id_ in entry_ids_list] - id_filter = " OR ".join([f"__entry_id = '{id_}'" for id_ in escaped_ids]) - - result = delta_table.to_pyarrow_table(filter=id_filter) + # Use schema-preserving read with filters + filter_expr = self._create_entry_ids_filter(entry_ids_list) + result = self._read_table_with_schema_preservation( + delta_table, source_path, filters=filter_expr + ) if len(result) == 0: return None @@ -383,8 +869,6 @@ def get_records_by_ids( # Handle entry_id column based on parameter return self._handle_entry_id_column(result, add_entry_id_column) - except TableNotFoundError: - return None except Exception as e: logger.error( f"Error getting records by IDs from {'/'.join(source_path)}: {e}" @@ -462,6 +946,9 @@ def delete_source(self, source_path: tuple[str, ...]) -> bool: """ self._validate_source_path(source_path) + # Flush any pending batches first + self._flush_batch(source_path) + table_path = self._get_table_path(source_path) source_key = self._get_source_key(source_path) @@ -469,9 +956,11 @@ def delete_source(self, source_path: tuple[str, ...]) -> bool: return False try: - # Remove from cache + # Remove from caches if source_key in self._delta_table_cache: del self._delta_table_cache[source_key] + if source_key in self._schema_cache: + del self._schema_cache[source_key] # Remove directory import shutil @@ -498,21 +987,26 @@ def delete_record(self, source_path: tuple[str, ...], entry_id: str) -> bool: """ 
self._validate_source_path(source_path) - table_path = self._get_table_path(source_path) + # Flush any pending batches first + self._flush_batch(source_path) - try: - delta_table = DeltaTable(str(table_path)) + delta_table = self._get_or_create_delta_table(source_path) + if delta_table is None: + return False - # Check if record exists - escaped_entry_id = entry_id.replace("'", "''") - existing = delta_table.to_pyarrow_table( - filter=f"__entry_id = '{escaped_entry_id}'" + try: + # Check if record exists using proper filter + filter_expr = self._create_entry_id_filter(entry_id) + existing = self._read_table_with_schema_preservation( + delta_table, source_path, filters=filter_expr ) if len(existing) == 0: return False - # Delete the record - delta_table.delete(f"__entry_id = '{escaped_entry_id}'") + # Delete the record using SQL-style predicate (this is correct for delete operations) + delta_table.delete( + f"__entry_id = '{entry_id.replace(chr(39), chr(39) + chr(39))}'" + ) # Update cache source_key = self._get_source_key(source_path) @@ -521,8 +1015,6 @@ def delete_record(self, source_path: tuple[str, ...], entry_id: str) -> bool: logger.debug(f"Deleted record {entry_id} from {'/'.join(source_path)}") return True - except TableNotFoundError: - return False except Exception as e: logger.error( f"Error deleting record {entry_id} from {'/'.join(source_path)}: {e}" @@ -541,27 +1033,48 @@ def get_table_info(self, source_path: tuple[str, ...]) -> dict[str, Any] | None: """ self._validate_source_path(source_path) - table_path = self._get_table_path(source_path) + delta_table = self._get_or_create_delta_table(source_path) + if delta_table is None: + return None try: - delta_table = DeltaTable(str(table_path)) - # Get basic info schema = delta_table.schema() history = delta_table.history() + source_key = self._get_source_key(source_path) + + # Add pending batch info + pending_info = self.get_pending_batch_info() + pending_count = pending_info.get(source_key, 0) + + # Get original schema info + original_schema = self._load_original_schema(source_path) return { - "path": str(table_path), + "path": str(self._get_table_path(source_path)), "source_path": source_path, "schema": schema, + "original_schema": original_schema, "version": delta_table.version(), "num_files": len(delta_table.files()), "history_length": len(history), "latest_commit": history[0] if history else None, + "pending_records": pending_count, } - except TableNotFoundError: - return None except Exception as e: logger.error(f"Error getting table info for {'/'.join(source_path)}: {e}") return None + + def get_original_schema(self, source_path: tuple[str, ...]) -> pa.Schema | None: + """ + Get the original schema (without __entry_id column) for a source path. + + Args: + source_path: Tuple of path components + + Returns: + Original Arrow schema or None if not found + """ + self._validate_source_path(source_path) + return self._load_original_schema(source_path) From 1b7519e7a929e9e70695d5e257ee748f4787fa47 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 2 Jul 2025 00:57:46 +0000 Subject: [PATCH 43/57] feat: better handling of stores and add flushing to stores and pipeline --- src/orcapod/pipeline/pipeline.py | 25 ++++++++++++++++++++++--- src/orcapod/stores/types.py | 4 ++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 7e04d96..a30aa62 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -34,13 +34,23 @@ def __init__( self, name: str | tuple[str, ...], pipeline_store: ArrowDataStore, - results_store: ArrowDataStore, + results_store: ArrowDataStore | None = None, auto_compile: bool = True, ) -> None: super().__init__() if not isinstance(name, tuple): name = (name,) self.name = name + self.pipeline_store_path_prefix = self.name + self.results_store_path_prefix = () + if results_store is None: + if pipeline_store is None: + raise ValueError( + "Either pipeline_store or results_store must be provided" + ) + results_store = pipeline_store + self.results_store_path_prefix = self.name + ("_results",) + self.pipeline_store = pipeline_store self.results_store = results_store self.labels_to_nodes = {} @@ -78,6 +88,12 @@ def save(self, path: Path | str) -> None: temp_path.unlink() raise + def flush(self) -> None: + """Flush all pending writes to the data store""" + self.pipeline_store.flush() + self.results_store.flush() + logger.info("Pipeline stores flushed") + def record(self, invocation: Invocation) -> None: """ Record an invocation in the pipeline. @@ -93,13 +109,14 @@ def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node input_nodes, output_store=self.results_store, tag_store=self.pipeline_store, - store_path_prefix=self.name, + output_store_path_prefix=self.results_store_path_prefix, + tag_store_path_prefix=self.pipeline_store_path_prefix, ) return KernelNode( kernel, input_nodes, output_store=self.pipeline_store, - store_path_prefix=self.name, + store_path_prefix=self.pipeline_store_path_prefix, ) def compile(self): @@ -175,6 +192,8 @@ def run(self, full_sync: bool = False) -> None: node.reset_cache() node.flow() + self.flush() + @classmethod def load(cls, path: Path | str) -> "Pipeline": """Load complete pipeline state""" diff --git a/src/orcapod/stores/types.py b/src/orcapod/stores/types.py index da7e492..42b0ed5 100644 --- a/src/orcapod/stores/types.py +++ b/src/orcapod/stores/types.py @@ -80,3 +80,7 @@ def get_records_by_ids_as_polars( ) -> pl.LazyFrame | None: """Retrieve records by entry IDs as a single Polars DataFrame.""" ... + + def flush(self) -> None: + """Flush all pending writes/saves to the data store.""" + ... From 07fd76e340bfd30b69578766b1917cd27eacee17 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 2 Jul 2025 00:58:11 +0000 Subject: [PATCH 44/57] feat: integrate actual saving to parquet into simple in memory store --- src/orcapod/stores/arrow_data_stores.py | 205 ++++++++++++------------ 1 file changed, 102 insertions(+), 103 deletions(-) diff --git a/src/orcapod/stores/arrow_data_stores.py b/src/orcapod/stores/arrow_data_stores.py index 2897ead..9d001a6 100644 --- a/src/orcapod/stores/arrow_data_stores.py +++ b/src/orcapod/stores/arrow_data_stores.py @@ -8,6 +8,7 @@ from datetime import datetime, timedelta import logging from orcapod.stores.types import DuplicateError +from pathlib import Path # Module-level logger logger = logging.getLogger(__name__) @@ -101,7 +102,9 @@ class SimpleInMemoryDataStore: Uses dict of dict of Arrow tables for efficient storage and retrieval. """ - def __init__(self, duplicate_entry_behavior: str = "error"): + def __init__( + self, path: str | Path | None = None, duplicate_entry_behavior: str = "error" + ): """ Initialize the InMemoryArrowDataStore. @@ -120,6 +123,12 @@ def __init__(self, duplicate_entry_behavior: str = "error"): logger.info( f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'" ) + self.base_path = Path(path) if path else None + if self.base_path: + try: + self.base_path.mkdir(parents=True, exist_ok=True) + except Exception as e: + logger.error(f"Error creating base path {self.base_path}: {e}") def _get_source_key(self, source_path: tuple[str, ...]) -> str: """Generate key for source storage.""" @@ -170,10 +179,16 @@ def add_record( logger.debug(f"{action} record {entry_id} in {source_key}") return arrow_data + def load_existing_record(self, source_path: tuple[str, ...]): + source_key = self._get_source_key(source_path) + if self.base_path is not None and source_key not in self._in_memory_store: + self.load_from_parquet(self.base_path, source_path) + def get_record( self, source_path: tuple[str, ...], entry_id: str ) -> pa.Table | None: """Get a specific record.""" + self.load_existing_record(source_path) source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) return local_data.get(entry_id) @@ -182,6 +197,7 @@ def get_all_records( self, source_path: tuple[str, ...], add_entry_id_column: bool | str = False ) -> pa.Table | None: """Retrieve all records for a given source as a single table.""" + self.load_existing_record(source_path) source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) @@ -257,6 +273,8 @@ def get_records_by_ids( f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" ) + self.load_existing_record(source_path) + source_key = self._get_source_key(source_path) local_data = self._in_memory_store.get(source_key, {}) @@ -394,19 +412,12 @@ def save_to_parquet(self, base_path: str | Path) -> None: saved_count = 0 - for source_key, local_data in self._in_memory_store.items(): + for source_id, local_data in self._in_memory_store.items(): if not local_data: continue - # Parse source_name and source_id from the key - if ":" not in source_key: - logger.warning(f"Invalid source key format: {source_key}, skipping") - continue - - source_name, source_id = source_key.split(":", 1) - # Create directory structure - source_dir = base_path / source_name / source_id + source_dir = base_path / source_id source_dir.mkdir(parents=True, exist_ok=True) # Combine all tables for this source with entry_id column @@ -430,12 +441,14 @@ def save_to_parquet(self, base_path: str | 
Path) -> None: saved_count += 1 logger.debug( - f"Saved {len(combined_table)} records for {source_key} to {parquet_path}" + f"Saved {len(combined_table)} records for {source_id} to {parquet_path}" ) logger.info(f"Saved {saved_count} sources to Parquet files in {base_path}") - def load_from_parquet(self, base_path: str | Path) -> None: + def load_from_parquet( + self, base_path: str | Path, source_path: tuple[str, ...] + ) -> None: """ Load data from Parquet files with the expected directory structure. @@ -444,113 +457,99 @@ def load_from_parquet(self, base_path: str | Path) -> None: Args: base_path: Base directory path containing the Parquet files """ - base_path = Path(base_path) - if not base_path.exists(): + source_key = self._get_source_key(source_path) + target_path = Path(base_path) / source_key + + if not target_path.exists(): logger.warning(f"Base path {base_path} does not exist") return - # Clear existing data - self._in_memory_store.clear() - loaded_count = 0 - # Traverse directory structure: source_name/source_id/ - for source_name_dir in base_path.iterdir(): - if not source_name_dir.is_dir(): - continue - - source_name = source_name_dir.name + # Look for Parquet files in this directory + parquet_files = list(target_path.glob("*.parquet")) + if not parquet_files: + logger.debug(f"No Parquet files found in {target_path}") + return - for source_id_dir in source_name_dir.iterdir(): - if not source_id_dir.is_dir(): - continue + # Load all Parquet files and combine them + all_records = [] - source_id = source_id_dir.name - source_key = self._get_source_key((source_name, source_id)) + for parquet_file in parquet_files: + try: + import pyarrow.parquet as pq - # Look for Parquet files in this directory - parquet_files = list(source_id_dir.glob("*.parquet")) + table = pq.read_table(parquet_file) - if not parquet_files: - logger.debug(f"No Parquet files found in {source_id_dir}") + # Validate that __entry_id column exists + if "__entry_id" not in table.column_names: + logger.warning( + f"Parquet file {parquet_file} missing __entry_id column, skipping" + ) continue - # Load all Parquet files and combine them - all_records = [] + all_records.append(table) + logger.debug(f"Loaded {len(table)} records from {parquet_file}") - for parquet_file in parquet_files: - try: - import pyarrow.parquet as pq - - table = pq.read_table(parquet_file) - - # Validate that __entry_id column exists - if "__entry_id" not in table.column_names: - logger.warning( - f"Parquet file {parquet_file} missing __entry_id column, skipping" - ) - continue - - all_records.append(table) - logger.debug(f"Loaded {len(table)} records from {parquet_file}") - - except Exception as e: - logger.error(f"Failed to load Parquet file {parquet_file}: {e}") - continue + except Exception as e: + logger.error(f"Failed to load Parquet file {parquet_file}: {e}") + continue - # Process all records for this source - if all_records: - # Combine all tables - if len(all_records) == 1: - combined_table = all_records[0] - else: - combined_table = pa.concat_tables(all_records) - - # Split back into individual records by entry_id - local_data = {} - entry_ids = combined_table.column("__entry_id").to_pylist() - - # Group records by entry_id - entry_id_groups = {} - for i, entry_id in enumerate(entry_ids): - if entry_id not in entry_id_groups: - entry_id_groups[entry_id] = [] - entry_id_groups[entry_id].append(i) - - # Extract each entry_id's records - for entry_id, indices in entry_id_groups.items(): - # Take rows for this entry_id and remove 
__entry_id column - entry_table = combined_table.take(indices) - - # Remove __entry_id column - column_names = entry_table.column_names - if "__entry_id" in column_names: - indices_to_keep = [ - i - for i, name in enumerate(column_names) - if name != "__entry_id" - ] - entry_table = entry_table.select(indices_to_keep) - - local_data[entry_id] = entry_table - - self._in_memory_store[source_key] = local_data - loaded_count += 1 - - record_count = len(combined_table) - unique_entries = len(entry_id_groups) - logger.debug( - f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}" - ) + # Process all records for this source + if all_records: + # Combine all tables + if len(all_records) == 1: + combined_table = all_records[0] + else: + combined_table = pa.concat_tables(all_records) + + # Split back into individual records by entry_id + local_data = {} + entry_ids = combined_table.column("__entry_id").to_pylist() + + # Group records by entry_id + entry_id_groups = {} + for i, entry_id in enumerate(entry_ids): + if entry_id not in entry_id_groups: + entry_id_groups[entry_id] = [] + entry_id_groups[entry_id].append(i) + + # Extract each entry_id's records + for entry_id, indices in entry_id_groups.items(): + # Take rows for this entry_id and remove __entry_id column + entry_table = combined_table.take(indices) + + # Remove __entry_id column + column_names = entry_table.column_names + if "__entry_id" in column_names: + indices_to_keep = [ + i for i, name in enumerate(column_names) if name != "__entry_id" + ] + entry_table = entry_table.select(indices_to_keep) + + local_data[entry_id] = entry_table + + self._in_memory_store[source_key] = local_data + loaded_count += 1 + + record_count = len(combined_table) + unique_entries = len(entry_id_groups) + logger.info( + f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}" + ) - logger.info(f"Loaded {loaded_count} sources from Parquet files in {base_path}") + def flush(self): + """ + Flush all in-memory data to Parquet files in the base path. + This will overwrite existing files. + """ + if self.base_path is None: + logger.warning("Base path is not set, cannot flush data") + return - # Log summary of loaded data - total_records = sum( - len(local_data) for local_data in self._in_memory_store.values() - ) - logger.info(f"Total records loaded: {total_records}") + logger.info(f"Flushing data to Parquet files in {self.base_path}") + self.save_to_parquet(self.base_path) @dataclass From 8411b40833ee28ced7a001be1545ea528004b995 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 2 Jul 2025 00:58:51 +0000 Subject: [PATCH 45/57] refactor: cleanup improt and comment out old packet converter for future removal --- src/orcapod/types/packet_converter.py | 364 +++++++++++++------------- src/orcapod/types/schemas.py | 1 - 2 files changed, 182 insertions(+), 183 deletions(-) diff --git a/src/orcapod/types/packet_converter.py b/src/orcapod/types/packet_converter.py index e486222..6edea00 100644 --- a/src/orcapod/types/packet_converter.py +++ b/src/orcapod/types/packet_converter.py @@ -1,182 +1,182 @@ -from orcapod.types.core import TypeSpec, TypeHandler -from orcapod.types.packets import Packet, PacketLike -from orcapod.types.semantic_type_registry import ( - SemanticTypeRegistry, - TypeInfo, - get_metadata_from_schema, - arrow_to_dicts, -) -from typing import Any -from collections.abc import Mapping, Sequence -import pyarrow as pa -import logging - -logger = logging.getLogger(__name__) - - -def is_packet_supported( - python_type_info: TypeSpec, - registry: SemanticTypeRegistry, - type_lut: dict | None = None, -) -> bool: - """Check if all types in the packet are supported by the registry or known to the default lut.""" - if type_lut is None: - type_lut = {} - return all( - python_type in registry or python_type in type_lut - for python_type in python_type_info.values() - ) - - -class PacketConverter: - def __init__(self, python_type_spec: TypeSpec, registry: SemanticTypeRegistry): - self.python_type_spec = python_type_spec - self.registry = registry - - # Lookup handlers and type info for fast access - self.handlers: dict[str, TypeHandler] = {} - self.storage_type_info: dict[str, TypeInfo] = {} - - self.expected_key_set = set(python_type_spec.keys()) - - # prepare the corresponding arrow table schema with metadata - self.keys_with_handlers, self.schema = create_schema_from_python_type_info( - python_type_spec, registry - ) - - self.semantic_type_lut = get_metadata_from_schema(self.schema, b"semantic_type") - - def _check_key_consistency(self, keys): - """Check if the provided keys match the expected keys.""" - keys_set = set(keys) - if keys_set != self.expected_key_set: - missing_keys = self.expected_key_set - keys_set - extra_keys = keys_set - self.expected_key_set - error_parts = [] - if missing_keys: - error_parts.append(f"Missing keys: {missing_keys}") - if extra_keys: - error_parts.append(f"Extra keys: {extra_keys}") - - raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") - - def _to_storage_packet(self, packet: PacketLike) -> dict[str, Any]: - """Convert packet to storage representation. - - Args: - packet: Dictionary mapping parameter names to Python values - - Returns: - Dictionary with same keys but values converted to storage format - - Raises: - KeyError: If packet keys don't match the expected type_info keys - TypeError: If value type doesn't match expected type - ValueError: If conversion fails - """ - # Validate packet keys - packet_keys = set(packet.keys()) - - self._check_key_consistency(packet_keys) - - # Convert each value - storage_packet: dict[str, Any] = dict(packet) # Start with a copy of the packet - - for key, handler in self.keys_with_handlers: - try: - storage_packet[key] = handler.python_to_storage(storage_packet[key]) - except Exception as e: - raise ValueError(f"Failed to convert value for '{key}': {e}") from e - - return storage_packet - - def _from_storage_packet(self, storage_packet: Mapping[str, Any]) -> PacketLike: - """Convert storage packet back to Python packet. 
- - Args: - storage_packet: Dictionary with values in storage format - - Returns: - Packet with values converted back to Python types - - Raises: - KeyError: If storage packet keys don't match the expected type_info keys - TypeError: If value type doesn't match expected type - ValueError: If conversion fails - """ - # Validate storage packet keys - storage_keys = set(storage_packet.keys()) - - self._check_key_consistency(storage_keys) - - # Convert each value back to Python type - packet: PacketLike = dict(storage_packet) - - for key, handler in self.keys_with_handlers: - try: - packet[key] = handler.storage_to_python(storage_packet[key]) - except Exception as e: - raise ValueError(f"Failed to convert value for '{key}': {e}") from e - - return packet - - def to_arrow_table(self, packet: PacketLike | Sequence[PacketLike]) -> pa.Table: - """Convert packet to PyArrow Table with field metadata. - - Args: - packet: Dictionary mapping parameter names to Python values - - Returns: - PyArrow Table with the packet data as a single row - """ - # Convert packet to storage format - if not isinstance(packet, Sequence): - packets = [packet] - else: - packets = packet - - storage_packets = [self._to_storage_packet(p) for p in packets] - - # Create arrays - arrays = [] - for field in self.schema: - values = [p[field.name] for p in storage_packets] - array = pa.array(values, type=field.type) - arrays.append(array) - - return pa.Table.from_arrays(arrays, schema=self.schema) - - def from_arrow_table( - self, table: pa.Table, verify_semantic_equivalence: bool = True - ) -> list[Packet]: - """Convert Arrow table to packet with field metadata. - - Args: - table: PyArrow Table with metadata - - Returns: - List of packets converted from the Arrow table - """ - # Check for consistency in the semantic type mapping: - semantic_type_info = get_metadata_from_schema(table.schema, b"semantic_type") - - if semantic_type_info != self.semantic_type_lut: - if not verify_semantic_equivalence: - logger.warning( - "Arrow table semantic types do not match expected type registry. " - f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" - ) - else: - raise ValueError( - "Arrow table semantic types do not match expected type registry. 
" - f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" - ) - - # Create packets from the Arrow table - # TODO: make this more efficient - storage_packets: list[Packet] = arrow_to_dicts(table) # type: ignore - if not self.keys_with_handlers: - # no special handling required - return storage_packets - - return [Packet(self._from_storage_packet(packet)) for packet in storage_packets] +# from orcapod.types.core import TypeSpec, TypeHandler +# from orcapod.types.packets import Packet, PacketLike +# from orcapod.types.semantic_type_registry import ( +# SemanticTypeRegistry, +# TypeInfo, +# get_metadata_from_schema, +# arrow_to_dicts, +# ) +# from typing import Any +# from collections.abc import Mapping, Sequence +# import pyarrow as pa +# import logging + +# logger = logging.getLogger(__name__) + + +# def is_packet_supported( +# python_type_info: TypeSpec, +# registry: SemanticTypeRegistry, +# type_lut: dict | None = None, +# ) -> bool: +# """Check if all types in the packet are supported by the registry or known to the default lut.""" +# if type_lut is None: +# type_lut = {} +# return all( +# python_type in registry or python_type in type_lut +# for python_type in python_type_info.values() +# ) + + +# class PacketConverter: +# def __init__(self, python_type_spec: TypeSpec, registry: SemanticTypeRegistry): +# self.python_type_spec = python_type_spec +# self.registry = registry + +# # Lookup handlers and type info for fast access +# self.handlers: dict[str, TypeHandler] = {} +# self.storage_type_info: dict[str, TypeInfo] = {} + +# self.expected_key_set = set(python_type_spec.keys()) + +# # prepare the corresponding arrow table schema with metadata +# self.keys_with_handlers, self.schema = create_schema_from_python_type_info( +# python_type_spec, registry +# ) + +# self.semantic_type_lut = get_metadata_from_schema(self.schema, b"semantic_type") + +# def _check_key_consistency(self, keys): +# """Check if the provided keys match the expected keys.""" +# keys_set = set(keys) +# if keys_set != self.expected_key_set: +# missing_keys = self.expected_key_set - keys_set +# extra_keys = keys_set - self.expected_key_set +# error_parts = [] +# if missing_keys: +# error_parts.append(f"Missing keys: {missing_keys}") +# if extra_keys: +# error_parts.append(f"Extra keys: {extra_keys}") + +# raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") + +# def _to_storage_packet(self, packet: PacketLike) -> dict[str, Any]: +# """Convert packet to storage representation. + +# Args: +# packet: Dictionary mapping parameter names to Python values + +# Returns: +# Dictionary with same keys but values converted to storage format + +# Raises: +# KeyError: If packet keys don't match the expected type_info keys +# TypeError: If value type doesn't match expected type +# ValueError: If conversion fails +# """ +# # Validate packet keys +# packet_keys = set(packet.keys()) + +# self._check_key_consistency(packet_keys) + +# # Convert each value +# storage_packet: dict[str, Any] = dict(packet) # Start with a copy of the packet + +# for key, handler in self.keys_with_handlers: +# try: +# storage_packet[key] = handler.python_to_storage(storage_packet[key]) +# except Exception as e: +# raise ValueError(f"Failed to convert value for '{key}': {e}") from e + +# return storage_packet + +# def _from_storage_packet(self, storage_packet: Mapping[str, Any]) -> PacketLike: +# """Convert storage packet back to Python packet. 
+ +# Args: +# storage_packet: Dictionary with values in storage format + +# Returns: +# Packet with values converted back to Python types + +# Raises: +# KeyError: If storage packet keys don't match the expected type_info keys +# TypeError: If value type doesn't match expected type +# ValueError: If conversion fails +# """ +# # Validate storage packet keys +# storage_keys = set(storage_packet.keys()) + +# self._check_key_consistency(storage_keys) + +# # Convert each value back to Python type +# packet: PacketLike = dict(storage_packet) + +# for key, handler in self.keys_with_handlers: +# try: +# packet[key] = handler.storage_to_python(storage_packet[key]) +# except Exception as e: +# raise ValueError(f"Failed to convert value for '{key}': {e}") from e + +# return packet + +# def to_arrow_table(self, packet: PacketLike | Sequence[PacketLike]) -> pa.Table: +# """Convert packet to PyArrow Table with field metadata. + +# Args: +# packet: Dictionary mapping parameter names to Python values + +# Returns: +# PyArrow Table with the packet data as a single row +# """ +# # Convert packet to storage format +# if not isinstance(packet, Sequence): +# packets = [packet] +# else: +# packets = packet + +# storage_packets = [self._to_storage_packet(p) for p in packets] + +# # Create arrays +# arrays = [] +# for field in self.schema: +# values = [p[field.name] for p in storage_packets] +# array = pa.array(values, type=field.type) +# arrays.append(array) + +# return pa.Table.from_arrays(arrays, schema=self.schema) + +# def from_arrow_table( +# self, table: pa.Table, verify_semantic_equivalence: bool = True +# ) -> list[Packet]: +# """Convert Arrow table to packet with field metadata. + +# Args: +# table: PyArrow Table with metadata + +# Returns: +# List of packets converted from the Arrow table +# """ +# # Check for consistency in the semantic type mapping: +# semantic_type_info = get_metadata_from_schema(table.schema, b"semantic_type") + +# if semantic_type_info != self.semantic_type_lut: +# if not verify_semantic_equivalence: +# logger.warning( +# "Arrow table semantic types do not match expected type registry. " +# f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" +# ) +# else: +# raise ValueError( +# "Arrow table semantic types do not match expected type registry. " +# f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" +# ) + +# # Create packets from the Arrow table +# # TODO: make this more efficient +# storage_packets: list[Packet] = arrow_to_dicts(table) # type: ignore +# if not self.keys_with_handlers: +# # no special handling required +# return storage_packets + +# return [Packet(self._from_storage_packet(packet)) for packet in storage_packets] diff --git a/src/orcapod/types/schemas.py b/src/orcapod/types/schemas.py index dc2112f..35cc4f0 100644 --- a/src/orcapod/types/schemas.py +++ b/src/orcapod/types/schemas.py @@ -1,6 +1,5 @@ from orcapod.types import TypeSpec from orcapod.types.semantic_type_registry import SemanticTypeRegistry -from typing import Any import pyarrow as pa import datetime From d90e5c64f5615a5b4f14b960d7b01f3debc16e2e Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 19:43:15 +0000 Subject: [PATCH 46/57] fix: attach label on kernel invocation to the invocation object --- src/orcapod/core/base.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 2144c34..64b99cb 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -53,8 +53,6 @@ def post_forward_hook(self, output_stream: "SyncStream", **kwargs) -> "SyncStrea def __call__( self, *streams: "SyncStream", label: str | None = None, **kwargs ) -> "SyncStream": - if label is not None: - self.label = label # Special handling of Source: trigger call on source if passed as stream normalized_streams = [ stream() if isinstance(stream, Source) else stream for stream in streams @@ -64,7 +62,7 @@ def __call__( output_stream = self.forward(*pre_processed_streams, **kwargs) post_processed_stream = self.post_forward_hook(output_stream, **kwargs) # create an invocation instance - invocation = Invocation(self, pre_processed_streams) + invocation = Invocation(self, pre_processed_streams, label=label) # label the output_stream with the invocation that produced the stream post_processed_stream.invocation = invocation @@ -458,6 +456,7 @@ def map( packet_map: dict | None = None, tag_map: dict | None = None, drop_unmapped: bool = True, + label: str | None = None, ) -> "SyncStream": """ Returns a new stream that is the result of mapping the packets and tags in the stream. @@ -470,9 +469,11 @@ def map( output = self if packet_map is not None: - output = MapPackets(packet_map, drop_unmapped=drop_unmapped)(output) + output = MapPackets(packet_map, drop_unmapped=drop_unmapped, label=label)( + output + ) if tag_map is not None: - output = MapTags(tag_map, drop_unmapped=drop_unmapped)(output) + output = MapTags(tag_map, drop_unmapped=drop_unmapped, label=label)(output) return output From fe35abacab3fa1d60a55d646346fc3ed75369314 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 3 Jul 2025 19:43:40 +0000 Subject: [PATCH 47/57] fix: invoke superclass init --- src/orcapod/core/operators.py | 40 +++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/orcapod/core/operators.py b/src/orcapod/core/operators.py index bcf63e3..5049e8e 100644 --- a/src/orcapod/core/operators.py +++ b/src/orcapod/core/operators.py @@ -23,8 +23,8 @@ class Repeat(Operator): The repeat count is the number of times to repeat each packet. """ - def __init__(self, repeat_count: int) -> None: - super().__init__() + def __init__(self, repeat_count: int, **kwargs) -> None: + super().__init__(**kwargs) if not isinstance(repeat_count, int): raise TypeError("repeat_count must be an integer") if repeat_count < 0: @@ -381,8 +381,10 @@ class MapPackets(Operator): drop_unmapped=False, in which case unmapped keys will be retained. """ - def __init__(self, key_map: dict[str, str], drop_unmapped: bool = True) -> None: - super().__init__() + def __init__( + self, key_map: dict[str, str], drop_unmapped: bool = True, **kwargs + ) -> None: + super().__init__(**kwargs) self.key_map = key_map self.drop_unmapped = drop_unmapped @@ -481,8 +483,8 @@ class DefaultTag(Operator): tag already contains the same key, it will not be overwritten. 
""" - def __init__(self, default_tag: Tag) -> None: - super().__init__() + def __init__(self, default_tag: Tag, **kwargs) -> None: + super().__init__(**kwargs) self.default_tag = default_tag def forward(self, *streams: SyncStream) -> SyncStream: @@ -527,8 +529,10 @@ class MapTags(Operator): drop_unmapped=False, in which case unmapped tags will be retained. """ - def __init__(self, key_map: dict[str, str], drop_unmapped: bool = True) -> None: - super().__init__() + def __init__( + self, key_map: dict[str, str], drop_unmapped: bool = True, **kwargs + ) -> None: + super().__init__(**kwargs) self.key_map = key_map self.drop_unmapped = drop_unmapped @@ -658,8 +662,8 @@ class Filter(Operator): The predicate function should return True for packets that should be kept and False for packets that should be dropped. """ - def __init__(self, predicate: Callable[[Tag, Packet], bool]): - super().__init__() + def __init__(self, predicate: Callable[[Tag, Packet], bool], **kwargs): + super().__init__(**kwargs) self.predicate = predicate def forward(self, *streams: SyncStream) -> SyncStream: @@ -704,8 +708,10 @@ class Transform(Operator): The transformation function should return a tuple of (new_tag, new_packet). """ - def __init__(self, transform: Callable[[Tag, Packet], tuple[Tag, Packet]]): - super().__init__() + def __init__( + self, transform: Callable[[Tag, Packet], tuple[Tag, Packet]], **kwargs + ): + super().__init__(**kwargs) self.transform = transform def forward(self, *streams: SyncStream) -> SyncStream: @@ -742,8 +748,9 @@ def __init__( batch_size: int, tag_processor: None | Callable[[Collection[Tag]], Tag] = None, drop_last: bool = True, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.batch_size = batch_size if tag_processor is None: tag_processor = batch_tags # noqa: E731 @@ -806,8 +813,9 @@ def __init__( reduce_keys: bool = False, selection_function: Callable[[Collection[tuple[Tag, Packet]]], Collection[bool]] | None = None, + **kwargs, ) -> None: - super().__init__() + super().__init__(**kwargs) self.group_keys = group_keys self.reduce_keys = reduce_keys self.selection_function = selection_function @@ -875,8 +883,8 @@ class CacheStream(Operator): Call `clear_cache()` to clear the cache. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self.cache: list[tuple[Tag, Packet]] = [] self.is_cached = False From ef301b38cce0528cef80df6861bbee3477d0fe46 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 3 Jul 2025 19:44:13 +0000 Subject: [PATCH 48/57] feat: expose explicit check for assigned label on content identifiable base --- src/orcapod/hashing/content_identifiable.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/orcapod/hashing/content_identifiable.py b/src/orcapod/hashing/content_identifiable.py index ce1b6c3..1e48243 100644 --- a/src/orcapod/hashing/content_identifiable.py +++ b/src/orcapod/hashing/content_identifiable.py @@ -30,6 +30,16 @@ def __init__( ) self._label = label + @property + def has_assigned_label(self) -> bool: + """ + Check if the label is explicitly set for this object. + + Returns: + bool: True if the label is explicitly set, False otherwise. + """ + return self._label is not None + @property def label(self) -> str: """ From ead67045e11b32fe10603d783012896c428564bf Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 19:44:37 +0000 Subject: [PATCH 49/57] feat: add label on wrapped invocation --- src/orcapod/pipeline/pipeline.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index a30aa62..2a7d86e 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -102,7 +102,9 @@ def record(self, invocation: Invocation) -> None: super().record(invocation) self._dirty = True - def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node: + def wrap_invocation( + self, kernel: Kernel, input_nodes: Collection[Node], label: str | None = None + ) -> Node: if isinstance(kernel, FunctionPod): return FunctionPodNode( kernel, @@ -111,12 +113,14 @@ def wrap_invocation(self, kernel: Kernel, input_nodes: Collection[Node]) -> Node tag_store=self.pipeline_store, output_store_path_prefix=self.results_store_path_prefix, tag_store_path_prefix=self.pipeline_store_path_prefix, + label=label, ) return KernelNode( kernel, input_nodes, output_store=self.pipeline_store, store_path_prefix=self.pipeline_store_path_prefix, + label=label, ) def compile(self): @@ -133,7 +137,11 @@ def compile(self): for invocation in nx.topological_sort(G): # map streams to the new streams based on Nodes input_nodes = [edge_lut[stream] for stream in invocation.streams] - new_node = self.wrap_invocation(invocation.kernel, input_nodes) + label = None + if invocation.has_assigned_label: + # If the invocation has a label, use it directly + label = invocation.label + new_node = self.wrap_invocation(invocation.kernel, input_nodes, label=label) # register the new node against the original invocation node_lut[invocation] = new_node From cbb8754bfd53ecdea067ed71306f34ce15e4efd7 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 21:44:14 +0000 Subject: [PATCH 50/57] doc: add tutorial notebook --- .../01_quick_dive_into_orcapod.ipynb | 821 ++++++++++++++++++ 1 file changed, 821 insertions(+) create mode 100644 notebooks/tutorials/01_quick_dive_into_orcapod.ipynb diff --git a/notebooks/tutorials/01_quick_dive_into_orcapod.ipynb b/notebooks/tutorials/01_quick_dive_into_orcapod.ipynb new file mode 100644 index 0000000..2f99783 --- /dev/null +++ b/notebooks/tutorials/01_quick_dive_into_orcapod.ipynb @@ -0,0 +1,821 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "27cdd37d", + "metadata": {}, + "outputs": [], + "source": [ + "import orcapod as op" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e776b8dc", + "metadata": {}, + "outputs": [], + "source": [ + "N = 10\n", + "stream = op.SyncStreamFromLists(\n", + " tags=[{\"id\": i} for i in range(N)],\n", + " packets=[{\"x\": i, \"y\": i + 1} for i in range(N)],\n", + " tag_typespec={\"id\": int},\n", + " packet_typespec={\"x\": int, \"y\": int},\n", + " label=\"MySource\",\n", + ")\n", + "\n", + "word_stream = op.SyncStreamFromLists(\n", + " tags=[{\"id\": i} for i in range(N)],\n", + " packets=[{\"word1\": f\"hello {i}\", \"word2\": f\"world {i}\"} for i in range(N)],\n", + " tag_typespec={\"id\": int},\n", + " packet_typespec={\"word1\": str, \"word2\": str},\n", + " label=\"HelloWorld\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78ab941b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 0} {'x': 0, 'y': 1}\n", + "{'id': 1} {'x': 1, 'y': 2}\n", + "{'id': 2} {'x': 2, 'y': 3}\n", + "{'id': 3} {'x': 3, 'y': 4}\n", + "{'id': 4} {'x': 4, 'y': 5}\n", + "{'id': 5} {'x': 5, 'y': 6}\n", + "{'id': 6} {'x': 6, 'y': 7}\n", + "{'id': 7} {'x': 7, 'y': 8}\n", + "{'id': 8} {'x': 8, 'y': 9}\n", + "{'id': 9} {'x': 9, 'y': 10}\n" + ] + } + ], + "source": [ + "for tag, packet in stream:\n", + " print(tag, packet)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c32596f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 0} {'word1': 'hello 0', 'word2': 'world 0'}\n", + "{'id': 1} {'word1': 'hello 1', 'word2': 'world 1'}\n", + "{'id': 2} {'word1': 'hello 2', 'word2': 'world 2'}\n", + "{'id': 3} {'word1': 'hello 3', 'word2': 'world 3'}\n", + "{'id': 4} {'word1': 'hello 4', 'word2': 'world 4'}\n", + "{'id': 5} {'word1': 'hello 5', 'word2': 'world 5'}\n", + "{'id': 6} {'word1': 'hello 6', 'word2': 'world 6'}\n", + "{'id': 7} {'word1': 'hello 7', 'word2': 'world 7'}\n", + "{'id': 8} {'word1': 'hello 8', 'word2': 'world 8'}\n", + "{'id': 9} {'word1': 'hello 9', 'word2': 'world 9'}\n" + ] + } + ], + "source": [ + "for tag, packet in word_stream:\n", + " print(tag, packet)" + ] + }, + { + "cell_type": "markdown", + "id": "ea7eb5ed", + "metadata": {}, + "source": [ + "## Defining function pods" + ] + }, + { + "cell_type": "markdown", + "id": "891bbadf", + "metadata": {}, + "source": [ + "Now we define our own function pods to perform simple computation. \n", + "Defining a function pod is quite simple, you simply \n", + "1. define a regular function with type annotations\n", + "2. 
decorate with `op.function_pod`, passing in the name ('key') for the output value(s)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8f5d5dbc", + "metadata": {}, + "outputs": [], + "source": [ + "@op.function_pod(\"total\")\n", + "def total(x: int, y: int) -> int:\n", + " return x + y\n", + "\n", + "\n", + "@op.function_pod(\"delta\")\n", + "def delta(x: int, y: int) -> int:\n", + " return 2 * y - x\n", + "\n", + "\n", + "@op.function_pod(\"mult\")\n", + "def mult(x: int, y: int) -> int:\n", + " return x * y\n", + "\n", + "\n", + "@op.function_pod(\"concat_string\")\n", + "def concat(x: str, y: str) -> str:\n", + " return x + y\n" + ] + }, + { + "cell_type": "markdown", + "id": "bd843166", + "metadata": {}, + "source": [ + "Wrapped functions are now `FunctionPod` and expects to be called with streams as inputs. You can still access the original function through its `function` attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c0a191b2", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Expected SyncStream, got int for stream 5", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# this won't work, because it's expecting a stream as input\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mtotal\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m6\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/workspace/orcapod-python/src/orcapod/core/base.py:60\u001b[39m, in \u001b[36mKernel.__call__\u001b[39m\u001b[34m(self, label, *streams, **kwargs)\u001b[39m\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m stream \u001b[38;5;129;01min\u001b[39;00m streams:\n\u001b[32m 59\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(stream, SyncStream):\n\u001b[32m---> \u001b[39m\u001b[32m60\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 61\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExpected SyncStream, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(stream).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for stream \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstream\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 62\u001b[39m )\n\u001b[32m 63\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(stream, Source):\n\u001b[32m 64\u001b[39m \u001b[38;5;66;03m# if the stream is a Source, instantiate it\u001b[39;00m\n\u001b[32m 65\u001b[39m stream = stream()\n", + "\u001b[31mTypeError\u001b[39m: Expected SyncStream, got int for stream 5" + ] + } + ], + "source": [ + "# this won't work, because it's expecting a stream as input\n", + "total(5, 6)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88a9b698", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "11" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# but you can access original function this way\n", + "total.function(5, 6)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 7, + "id": "c8ad097f", + "metadata": {}, + "outputs": [], + "source": [ + "# Passing a stream into a pod does NOT immediately trigger execution, but rather returns another stream\n", + "\n", + "total_stream = total(stream)" + ] + }, + { + "cell_type": "markdown", + "id": "0af7a165", + "metadata": {}, + "source": [ + "Iterating through the stream or calling `flow` triggers the computation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "93c3f1a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 0} {'total': 1}\n", + "{'id': 1} {'total': 3}\n", + "{'id': 2} {'total': 5}\n", + "{'id': 3} {'total': 7}\n", + "{'id': 4} {'total': 9}\n", + "{'id': 5} {'total': 11}\n", + "{'id': 6} {'total': 13}\n", + "{'id': 7} {'total': 15}\n", + "{'id': 8} {'total': 17}\n", + "{'id': 9} {'total': 19}\n" + ] + } + ], + "source": [ + "for tag, packet in total_stream:\n", + " print(tag, packet)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cfadfb8f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[({'id': 0}, {'total': 1}),\n", + " ({'id': 1}, {'total': 3}),\n", + " ({'id': 2}, {'total': 5}),\n", + " ({'id': 3}, {'total': 7}),\n", + " ({'id': 4}, {'total': 9}),\n", + " ({'id': 5}, {'total': 11}),\n", + " ({'id': 6}, {'total': 13}),\n", + " ({'id': 7}, {'total': 15}),\n", + " ({'id': 8}, {'total': 17}),\n", + " ({'id': 9}, {'total': 19})]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_stream.flow()" + ] + }, + { + "cell_type": "markdown", + "id": "d1013dd1", + "metadata": {}, + "source": [ + "If you try to pass in an incompatible stream (stream whose packets don't match the expected inputs of the function), you will immediately get an error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2805282e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Key 'word1' not found in parameter types.\n" + ] + }, + { + "ename": "TypeError", + "evalue": "Input packet types {'word1': , 'word2': } is not compatible with the function's expected input types {'x': , 'y': }", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m total_stream = \u001b[43mtotal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mword_stream\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/workspace/orcapod-python/src/orcapod/core/base.py:75\u001b[39m, in \u001b[36mKernel.__call__\u001b[39m\u001b[34m(self, label, *streams, **kwargs)\u001b[39m\n\u001b[32m 69\u001b[39m normalized_streams = [\n\u001b[32m 70\u001b[39m stream() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(stream, Source) \u001b[38;5;28;01melse\u001b[39;00m stream\n\u001b[32m 71\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m stream \u001b[38;5;129;01min\u001b[39;00m verified_streams\n\u001b[32m 72\u001b[39m ]\n\u001b[32m 74\u001b[39m pre_processed_streams = \u001b[38;5;28mself\u001b[39m.pre_forward_hook(*normalized_streams, **kwargs)\n\u001b[32m---> \u001b[39m\u001b[32m75\u001b[39m output_stream = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mpre_processed_streams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 76\u001b[39m post_processed_stream = \u001b[38;5;28mself\u001b[39m.post_forward_hook(output_stream, **kwargs)\n\u001b[32m 77\u001b[39m \u001b[38;5;66;03m# create an invocation instance\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/workspace/orcapod-python/src/orcapod/core/pod.py:236\u001b[39m, in \u001b[36mFunctionPod.forward\u001b[39m\u001b[34m(self, *streams, **kwargs)\u001b[39m\n\u001b[32m 232\u001b[39m _, packet_typespec = stream.types(trigger_run=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 233\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m packet_typespec \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m check_typespec_compatibility(\n\u001b[32m 234\u001b[39m packet_typespec, \u001b[38;5;28mself\u001b[39m.function_input_typespec\n\u001b[32m 235\u001b[39m ):\n\u001b[32m--> \u001b[39m\u001b[32m236\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 237\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInput packet types \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpacket_typespec\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m is not compatible with the function\u001b[39m\u001b[33m'\u001b[39m\u001b[33ms expected input types \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.function_input_typespec\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 238\u001b[39m )\n\u001b[32m 239\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28msuper\u001b[39m().forward(*streams, **kwargs)\n", + "\u001b[31mTypeError\u001b[39m: Input packet types {'word1': , 'word2': } is not compatible with the function's expected input types {'x': , 'y': }" + ] + } + ], + "source": [ + "total_stream = total(word_stream)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4c9c030a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "({'id': int}, {'x': int, 'y': int})" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# you can check the tag and packet types of the stream\n", + "stream.types()" + ] + }, + { + "cell_type": "markdown", + "id": "3ba299b2", + "metadata": {}, + "source": [ + "## Defining pipeline" + ] + }, + { + "cell_type": "markdown", + "id": "1e1dd036", + "metadata": {}, + "source": [ + "We will now piece together multiple function pods into a pipeline. We do this by instantiating a `Pipeline` object. We will store the results into a simple data store." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8083f54a", + "metadata": {}, + "outputs": [], + "source": [ + "# Use simple data store, saving data to Parquet files\n", + "pipeline_store = op.stores.SimpleParquetDataStore(\"./example_data_store\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a475308c", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = op.Pipeline(\"test_pipeline\", pipeline_store)\n" + ] + }, + { + "cell_type": "markdown", + "id": "a42158b9", + "metadata": {}, + "source": [ + "Now we have a pipeline object, we can use it to define our pipeline by simply \"chaining\" together function pod calls." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f923ecf1", + "metadata": {}, + "outputs": [], + "source": [ + "with pipeline:\n", + " total_stream = total(stream)\n", + " delta_stream = delta(stream)\n", + " mult_stream = mult(\n", + " total_stream.map({\"total\": \"x\"}), delta_stream.map({\"delta\": \"y\"})\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "b67e9413", + "metadata": {}, + "source": [ + "And that's it! Now the elements of the pipeline is available as properties on the pipeline." + ] + }, + { + "cell_type": "markdown", + "id": "7ee41a20", + "metadata": {}, + "source": [ + "By default, the function pods are made available under the function's name in the pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "66230603", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FunctionPodNode>" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.total" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6587f2f2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FunctionPodNode>" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.mult" + ] + }, + { + "cell_type": "markdown", + "id": "16d0dba3", + "metadata": {}, + "source": [ + "Other implicitly created nodes such as joining of two streams are made available under the corresponding operator class (e.g. 
`Join`)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bd0dfba2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "KernelNode" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.Join" + ] + }, + { + "cell_type": "markdown", + "id": "71dba5c5", + "metadata": {}, + "source": [ + "You can list out all nodes through `nodes` property" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e22758ab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'MySource': KernelNode,\n", + " 'total': FunctionPodNode>,\n", + " 'delta': FunctionPodNode>,\n", + " 'MapPackets_0': KernelNode,\n", + " 'MapPackets_1': KernelNode,\n", + " 'Join': KernelNode,\n", + " 'mult': FunctionPodNode>}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.nodes" + ] + }, + { + "cell_type": "markdown", + "id": "039b617f", + "metadata": {}, + "source": [ + "You can easily rename any node using the pipeline's `rename` method" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0d1a470e", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.rename(\"MapPackets_0\", \"total_map\")\n", + "pipeline.rename(\"MapPackets_1\", \"mult_map\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "3a43984d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'MySource': KernelNode,\n", + " 'total': FunctionPodNode>,\n", + " 'delta': FunctionPodNode>,\n", + " 'Join': KernelNode,\n", + " 'mult': FunctionPodNode>,\n", + " 'total_map': KernelNode,\n", + " 'mult_map': KernelNode}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.nodes" + ] + }, + { + "cell_type": "markdown", + "id": "c438f111", + "metadata": {}, + "source": [ + "Renaming does NOT change the structure of the pipeline in anyway -- it simply changes how it's labeld for your convenience." + ] + }, + { + "cell_type": "markdown", + "id": "befa6107", + "metadata": {}, + "source": [ + "### Running pipeline and accessing results" + ] + }, + { + "cell_type": "markdown", + "id": "4d4412b1", + "metadata": {}, + "source": [ + "Since we just created the pipeline, there are no results associated with any node. You can get [Polars](https://pola.rs) DataFrame viewing into the results through the node's `df` attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "96106e09", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (10, 2)
idtotal
i64i64
01
13
25
37
49
511
613
715
817
919
" + ], + "text/plain": [ + "shape: (10, 2)\n", + "┌─────┬───────┐\n", + "│ id ┆ total │\n", + "│ --- ┆ --- │\n", + "│ i64 ┆ i64 │\n", + "╞═════╪═══════╡\n", + "│ 0 ┆ 1 │\n", + "│ 1 ┆ 3 │\n", + "│ 2 ┆ 5 │\n", + "│ 3 ┆ 7 │\n", + "│ 4 ┆ 9 │\n", + "│ 5 ┆ 11 │\n", + "│ 6 ┆ 13 │\n", + "│ 7 ┆ 15 │\n", + "│ 8 ┆ 17 │\n", + "│ 9 ┆ 19 │\n", + "└─────┴───────┘" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.total.df" + ] + }, + { + "cell_type": "markdown", + "id": "62b7e59a", + "metadata": {}, + "source": [ + "Before we run, the source nodes is also not \"recorded\" and thus will appear empty." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "33b449b6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (10, 3)
idxy
i64i64i64
001
112
223
334
445
556
667
778
889
9910
" + ], + "text/plain": [ + "shape: (10, 3)\n", + "┌─────┬─────┬─────┐\n", + "│ id ┆ x ┆ y │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════╪═════╪═════╡\n", + "│ 0 ┆ 0 ┆ 1 │\n", + "│ 1 ┆ 1 ┆ 2 │\n", + "│ 2 ┆ 2 ┆ 3 │\n", + "│ 3 ┆ 3 ┆ 4 │\n", + "│ 4 ┆ 4 ┆ 5 │\n", + "│ 5 ┆ 5 ┆ 6 │\n", + "│ 6 ┆ 6 ┆ 7 │\n", + "│ 7 ┆ 7 ┆ 8 │\n", + "│ 8 ┆ 8 ┆ 9 │\n", + "│ 9 ┆ 9 ┆ 10 │\n", + "└─────┴─────┴─────┘" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.MySource.df" + ] + }, + { + "cell_type": "markdown", + "id": "408e8012", + "metadata": {}, + "source": [ + "We can trigger the entire pipeline to run and record all results by simply calling the `run` method." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "189f943f", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "1674bec4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (10, 3)
idxy
i64i64i64
001
112
223
334
445
556
667
778
889
9910
" + ], + "text/plain": [ + "shape: (10, 3)\n", + "┌─────┬─────┬─────┐\n", + "│ id ┆ x ┆ y │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════╪═════╪═════╡\n", + "│ 0 ┆ 0 ┆ 1 │\n", + "│ 1 ┆ 1 ┆ 2 │\n", + "│ 2 ┆ 2 ┆ 3 │\n", + "│ 3 ┆ 3 ┆ 4 │\n", + "│ 4 ┆ 4 ┆ 5 │\n", + "│ 5 ┆ 5 ┆ 6 │\n", + "│ 6 ┆ 6 ┆ 7 │\n", + "│ 7 ┆ 7 ┆ 8 │\n", + "│ 8 ┆ 8 ┆ 9 │\n", + "│ 9 ┆ 9 ┆ 10 │\n", + "└─────┴─────┴─────┘" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.MySource.df" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "2b69d213", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (10, 2)
idtotal
i64i64
01
13
25
37
49
511
613
715
817
919
" + ], + "text/plain": [ + "shape: (10, 2)\n", + "┌─────┬───────┐\n", + "│ id ┆ total │\n", + "│ --- ┆ --- │\n", + "│ i64 ┆ i64 │\n", + "╞═════╪═══════╡\n", + "│ 0 ┆ 1 │\n", + "│ 1 ┆ 3 │\n", + "│ 2 ┆ 5 │\n", + "│ 3 ┆ 7 │\n", + "│ 4 ┆ 9 │\n", + "│ 5 ┆ 11 │\n", + "│ 6 ┆ 13 │\n", + "│ 7 ┆ 15 │\n", + "│ 8 ┆ 17 │\n", + "│ 9 ┆ 19 │\n", + "└─────┴───────┘" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.total.df" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "orcapod (3.13.3)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 73b2638aec7fe45989b8a35c805ee7ac8e777019 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 3 Jul 2025 21:45:01 +0000 Subject: [PATCH 51/57] refactor: clean up store package --- src/orcapod/stores/arrow_data_stores.py | 28 +- ...a_store.py => dict_transfer_data_store.py} | 2 +- src/orcapod/stores/optimized_memory_store.py | 445 ------------------ 3 files changed, 20 insertions(+), 455 deletions(-) rename src/orcapod/stores/{transfer_data_store.py => dict_transfer_data_store.py} (96%) delete mode 100644 src/orcapod/stores/optimized_memory_store.py diff --git a/src/orcapod/stores/arrow_data_stores.py b/src/orcapod/stores/arrow_data_stores.py index 9d001a6..93a2400 100644 --- a/src/orcapod/stores/arrow_data_stores.py +++ b/src/orcapod/stores/arrow_data_stores.py @@ -17,8 +17,11 @@ class MockArrowDataStore: """ Mock Arrow data store for testing purposes. - This class simulates the behavior of ParquetArrowDataStore without actually saving anything. - It is useful for unit tests where you want to avoid filesystem dependencies. + This class simulates the behavior of ArrowDataStore without actually saving anything. + It is useful for unit tests where you want to avoid any I/O operations or when you need + to test the behavior of your code without relying on external systems. If you need some + persistence of saved data, consider using SimpleParquetDataStore without providing a + file path instead. """ def __init__(self): @@ -93,13 +96,20 @@ def get_records_by_ids_as_polars( return None -class SimpleInMemoryDataStore: +class SimpleParquetDataStore: """ - In-memory Arrow data store, primarily to be used for testing purposes. - This class simulates the behavior of ParquetArrowDataStore without actual file I/O. - It is useful for unit tests where you want to avoid filesystem dependencies. - - Uses dict of dict of Arrow tables for efficient storage and retrieval. + Simple Parquet-based Arrow data store, primarily to be used for development purposes. + If no file path is provided, it will not save anything to disk. Instead, all data will be stored in memory. + If a file path is provided, it will save data to a single Parquet files in a directory structure reflecting + the provided source_path. To speed up the process, data will be stored in memory and only saved to disk + when the `flush` method is called. If used as part of pipeline, flush is automatically called + at the end of pipeline execution. + Note that this store provides only very basic functionality and is not suitable for production use. 
+ For each distinct source_path, only a single parquet file is created to store all data entries. + Appending is not efficient as it requires reading the entire file into the memory, appending new data, + and then writing the entire file back to disk. This is not suitable for large datasets or frequent updates. + However, for development/testing purposes, this data store provides a simple way to store and retrieve + data without the overhead of a full database or file system and provides very high performance. """ def __init__( @@ -462,7 +472,7 @@ def load_from_parquet( target_path = Path(base_path) / source_key if not target_path.exists(): - logger.warning(f"Base path {base_path} does not exist") + logger.info(f"Base path {base_path} does not exist") return loaded_count = 0 diff --git a/src/orcapod/stores/transfer_data_store.py b/src/orcapod/stores/dict_transfer_data_store.py similarity index 96% rename from src/orcapod/stores/transfer_data_store.py rename to src/orcapod/stores/dict_transfer_data_store.py index 9e393e0..7e8762f 100644 --- a/src/orcapod/stores/transfer_data_store.py +++ b/src/orcapod/stores/dict_transfer_data_store.py @@ -6,7 +6,7 @@ class TransferDataStore(DataStore): """ - A data store that allows transferring memoized packets between different data stores. + A data store that allows transferring recorded data between different data stores. This is useful for moving data between different storage backends. """ diff --git a/src/orcapod/stores/optimized_memory_store.py b/src/orcapod/stores/optimized_memory_store.py deleted file mode 100644 index 1859113..0000000 --- a/src/orcapod/stores/optimized_memory_store.py +++ /dev/null @@ -1,445 +0,0 @@ -import polars as pl -import pyarrow as pa -import logging -from typing import Any, Dict, List, Tuple, cast -from collections import defaultdict - -# Module-level logger -logger = logging.getLogger(__name__) - - -class ArrowBatchedPolarsDataStore: - """ - Arrow-batched Polars data store that minimizes Arrow<->Polars conversions. - - Key optimizations: - 1. Keep data in Arrow format during batching - 2. Only convert to Polars when consolidating or querying - 3. Batch Arrow tables and concatenate before conversion - 4. Maintain Arrow-based indexing for fast lookups - 5. Lazy Polars conversion only when needed - """ - - def __init__(self, duplicate_entry_behavior: str = "error", batch_size: int = 100): - """ - Initialize the ArrowBatchedPolarsDataStore. 
- - Args: - duplicate_entry_behavior: How to handle duplicate entry_ids: - - 'error': Raise ValueError when entry_id already exists - - 'overwrite': Replace existing entry with new data - batch_size: Number of records to batch before consolidating - """ - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - - self.duplicate_entry_behavior = duplicate_entry_behavior - self.batch_size = batch_size - - # Arrow batch buffer: {source_key: [(entry_id, arrow_table), ...]} - self._arrow_batches: Dict[str, List[Tuple[str, pa.Table]]] = defaultdict(list) - - # Consolidated Polars store: {source_key: polars_dataframe} - self._polars_store: Dict[str, pl.DataFrame] = {} - - # Entry ID index for fast lookups: {source_key: set[entry_ids]} - self._entry_index: Dict[str, set] = defaultdict(set) - - # Schema cache - self._schema_cache: Dict[str, pa.Schema] = {} - - logger.info( - f"Initialized ArrowBatchedPolarsDataStore with " - f"duplicate_entry_behavior='{duplicate_entry_behavior}', batch_size={batch_size}" - ) - - def _get_source_key(self, source_name: str, source_id: str) -> str: - """Generate key for source storage.""" - return f"{source_name}:{source_id}" - - def _add_entry_id_to_arrow_table(self, table: pa.Table, entry_id: str) -> pa.Table: - """Add entry_id column to Arrow table efficiently.""" - # Create entry_id array with the same length as the table - entry_id_array = pa.array([entry_id] * len(table), type=pa.string()) - - # Add column at the beginning for consistent ordering - return table.add_column(0, "__entry_id", entry_id_array) - - def _consolidate_arrow_batch(self, source_key: str) -> None: - """Consolidate Arrow batch into Polars DataFrame.""" - if source_key not in self._arrow_batches or not self._arrow_batches[source_key]: - return - - logger.debug( - f"Consolidating {len(self._arrow_batches[source_key])} Arrow tables for {source_key}" - ) - - # Prepare all Arrow tables with entry_id columns - arrow_tables_with_id = [] - - for entry_id, arrow_table in self._arrow_batches[source_key]: - table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) - arrow_tables_with_id.append(table_with_id) - - # Concatenate all Arrow tables at once (very fast) - if len(arrow_tables_with_id) == 1: - consolidated_arrow = arrow_tables_with_id[0] - else: - consolidated_arrow = pa.concat_tables(arrow_tables_with_id) - - # Single conversion to Polars - new_polars_df = cast(pl.DataFrame, pl.from_arrow(consolidated_arrow)) - - # Combine with existing Polars DataFrame if it exists - if source_key in self._polars_store: - existing_df = self._polars_store[source_key] - self._polars_store[source_key] = pl.concat([existing_df, new_polars_df]) - else: - self._polars_store[source_key] = new_polars_df - - # Clear the Arrow batch - self._arrow_batches[source_key].clear() - - logger.debug( - f"Consolidated to Polars DataFrame with {len(self._polars_store[source_key])} total rows" - ) - - def _force_consolidation(self, source_key: str) -> None: - """Force consolidation of Arrow batches.""" - if source_key in self._arrow_batches and self._arrow_batches[source_key]: - self._consolidate_arrow_batch(source_key) - - def _get_consolidated_dataframe(self, source_key: str) -> pl.DataFrame | None: - """Get consolidated Polars DataFrame, forcing consolidation if needed.""" - self._force_consolidation(source_key) - return self._polars_store.get(source_key) - - def add_record( - self, - source_name: str, - source_id: str, - entry_id: 
str, - arrow_data: pa.Table, - ) -> pa.Table: - """ - Add a record to the store using Arrow batching. - - This is the fastest path - no conversions, just Arrow table storage. - """ - source_key = self._get_source_key(source_name, source_id) - - # Check for duplicate entry - if entry_id in self._entry_index[source_key]: - if self.duplicate_entry_behavior == "error": - raise ValueError( - f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - else: - # Handle overwrite: remove from both Arrow batch and Polars store - # Remove from Arrow batch - self._arrow_batches[source_key] = [ - (eid, table) - for eid, table in self._arrow_batches[source_key] - if eid != entry_id - ] - - # Remove from Polars store if it exists - if source_key in self._polars_store: - self._polars_store[source_key] = self._polars_store[ - source_key - ].filter(pl.col("__entry_id") != entry_id) - - # Schema validation (cached) - if source_key in self._schema_cache: - if not self._schema_cache[source_key].equals(arrow_data.schema): - raise ValueError( - f"Schema mismatch for {source_key}. " - f"Expected: {self._schema_cache[source_key]}, " - f"Got: {arrow_data.schema}" - ) - else: - self._schema_cache[source_key] = arrow_data.schema - - # Add to Arrow batch (no conversion yet!) - self._arrow_batches[source_key].append((entry_id, arrow_data)) - self._entry_index[source_key].add(entry_id) - - # Consolidate if batch is full - if len(self._arrow_batches[source_key]) >= self.batch_size: - self._consolidate_arrow_batch(source_key) - - logger.debug(f"Added entry {entry_id} to Arrow batch for {source_key}") - return arrow_data - - def get_record( - self, source_name: str, source_id: str, entry_id: str - ) -> pa.Table | None: - """Get a specific record with optimized lookup.""" - source_key = self._get_source_key(source_name, source_id) - - # Quick existence check - if entry_id not in self._entry_index[source_key]: - return None - - # Check Arrow batch first (most recent data) - for batch_entry_id, arrow_table in self._arrow_batches[source_key]: - if batch_entry_id == entry_id: - return arrow_table - - # Check consolidated Polars store - df = self._get_consolidated_dataframe(source_key) - if df is None: - return None - - # Filter and convert back to Arrow - filtered_df = df.filter(pl.col("__entry_id") == entry_id).drop("__entry_id") - - if filtered_df.height == 0: - return None - - return filtered_df.to_arrow() - - def get_all_records( - self, source_name: str, source_id: str, add_entry_id_column: bool | str = False - ) -> pa.Table | None: - """Retrieve all records as a single Arrow table.""" - source_key = self._get_source_key(source_name, source_id) - - # Force consolidation to include all data - df = self._get_consolidated_dataframe(source_key) - if df is None or df.height == 0: - return None - - # Handle entry_id column - if add_entry_id_column is False: - result_df = df.drop("__entry_id") - elif add_entry_id_column is True: - result_df = df - elif isinstance(add_entry_id_column, str): - result_df = df.rename({"__entry_id": add_entry_id_column}) - else: - result_df = df.drop("__entry_id") - - return result_df.to_arrow() - - def get_all_records_as_polars( - self, source_name: str, source_id: str - ) -> pl.LazyFrame | None: - """Retrieve all records as a Polars LazyFrame.""" - source_key = self._get_source_key(source_name, source_id) - - df = self._get_consolidated_dataframe(source_key) - if df is None or df.height == 0: - return None - - return 
df.drop("__entry_id").lazy() - - def get_records_by_ids( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """Retrieve records by entry IDs efficiently.""" - # Convert input to list for processing - if isinstance(entry_ids, list): - if not entry_ids: - return None - entry_ids_list = entry_ids - elif isinstance(entry_ids, pl.Series): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_list() - elif isinstance(entry_ids, pa.Array): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_pylist() - else: - raise TypeError(f"entry_ids must be list[str], pl.Series, or pa.Array") - - source_key = self._get_source_key(source_name, source_id) - - # Quick filter using index - existing_entries = [ - entry_id - for entry_id in entry_ids_list - if entry_id in self._entry_index[source_key] - ] - - if not existing_entries and not preserve_input_order: - return None - - # Collect from Arrow batch first - batch_tables = [] - found_in_batch = set() - - for entry_id, arrow_table in self._arrow_batches[source_key]: - if entry_id in entry_ids_list: - table_with_id = self._add_entry_id_to_arrow_table(arrow_table, entry_id) - batch_tables.append(table_with_id) - found_in_batch.add(entry_id) - - # Get remaining from consolidated store - remaining_ids = [eid for eid in existing_entries if eid not in found_in_batch] - - consolidated_tables = [] - if remaining_ids: - df = self._get_consolidated_dataframe(source_key) - if df is not None: - if preserve_input_order: - ordered_df = pl.DataFrame({"__entry_id": entry_ids_list}) - result_df = ordered_df.join(df, on="__entry_id", how="left") - else: - result_df = df.filter(pl.col("__entry_id").is_in(remaining_ids)) - - if result_df.height > 0: - consolidated_tables.append(result_df.to_arrow()) - - # Combine all results - all_tables = batch_tables + consolidated_tables - - if not all_tables: - return None - - # Concatenate Arrow tables - if len(all_tables) == 1: - result_table = all_tables[0] - else: - result_table = pa.concat_tables(all_tables) - - # Handle entry_id column - if add_entry_id_column is False: - # Remove __entry_id column - column_names = result_table.column_names - if "__entry_id" in column_names: - indices = [ - i for i, name in enumerate(column_names) if name != "__entry_id" - ] - result_table = result_table.select(indices) - elif isinstance(add_entry_id_column, str): - # Rename __entry_id column - schema = result_table.schema - new_names = [ - add_entry_id_column if name == "__entry_id" else name - for name in schema.names - ] - result_table = result_table.rename_columns(new_names) - - return result_table - - def get_records_by_ids_as_polars( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - """Retrieve records by entry IDs as Polars LazyFrame.""" - arrow_result = self.get_records_by_ids( - source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order - ) - - if arrow_result is None: - return None - - pl_result = cast(pl.DataFrame, pl.from_arrow(arrow_result)) - - return pl_result.lazy() - - def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: - """Check if entry exists using the index.""" - source_key = self._get_source_key(source_name, source_id) - return entry_id in 
self._entry_index[source_key] - - def list_entries(self, source_name: str, source_id: str) -> set[str]: - """List all entry IDs using the index.""" - source_key = self._get_source_key(source_name, source_id) - return self._entry_index[source_key].copy() - - def list_sources(self) -> set[tuple[str, str]]: - """List all source combinations.""" - sources = set() - for source_key in self._entry_index.keys(): - if ":" in source_key: - source_name, source_id = source_key.split(":", 1) - sources.add((source_name, source_id)) - return sources - - def force_consolidation(self) -> None: - """Force consolidation of all Arrow batches.""" - for source_key in list(self._arrow_batches.keys()): - self._force_consolidation(source_key) - logger.info("Forced consolidation of all Arrow batches") - - def clear_source(self, source_name: str, source_id: str) -> None: - """Clear all data for a source.""" - source_key = self._get_source_key(source_name, source_id) - - if source_key in self._arrow_batches: - del self._arrow_batches[source_key] - if source_key in self._polars_store: - del self._polars_store[source_key] - if source_key in self._entry_index: - del self._entry_index[source_key] - if source_key in self._schema_cache: - del self._schema_cache[source_key] - - logger.debug(f"Cleared source {source_key}") - - def clear_all(self) -> None: - """Clear all data.""" - self._arrow_batches.clear() - self._polars_store.clear() - self._entry_index.clear() - self._schema_cache.clear() - logger.info("Cleared all data") - - def get_stats(self) -> dict[str, Any]: - """Get comprehensive statistics.""" - total_records = sum(len(entries) for entries in self._entry_index.values()) - total_batched = sum(len(batch) for batch in self._arrow_batches.values()) - total_consolidated = ( - sum(len(df) for df in self._polars_store.values()) - if self._polars_store - else 0 - ) - - source_stats = [] - for source_key in self._entry_index.keys(): - record_count = len(self._entry_index[source_key]) - batched_count = len(self._arrow_batches.get(source_key, [])) - consolidated_count = 0 - - if source_key in self._polars_store: - consolidated_count = len(self._polars_store[source_key]) - - source_stats.append( - { - "source_key": source_key, - "total_records": record_count, - "batched_records": batched_count, - "consolidated_records": consolidated_count, - } - ) - - return { - "total_records": total_records, - "total_sources": len(self._entry_index), - "total_batched": total_batched, - "total_consolidated": total_consolidated, - "batch_size": self.batch_size, - "duplicate_entry_behavior": self.duplicate_entry_behavior, - "source_details": source_stats, - } - - def optimize_for_reads(self) -> None: - """Optimize for read operations by consolidating all batches.""" - logger.info("Optimizing for reads - consolidating all Arrow batches...") - self.force_consolidation() - # Clear Arrow batches to save memory - self._arrow_batches.clear() - logger.info("Optimization complete") From 555a751c2cd808445c0fff9e41055f11c4a9b180 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 21:45:59 +0000 Subject: [PATCH 52/57] feat: improve pipeline usability with typechecks and convenience attributes --- src/orcapod/core/base.py | 15 ++++++++++++++- src/orcapod/core/pod.py | 19 ++++++++++++++++++- src/orcapod/pipeline/nodes.py | 14 ++++++++++++-- src/orcapod/pipeline/pipeline.py | 26 +++++++++++++++++++++----- src/orcapod/types/core.py | 2 +- 5 files changed, 66 insertions(+), 10 deletions(-) diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 64b99cb..367bc72 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -53,9 +53,22 @@ def post_forward_hook(self, output_stream: "SyncStream", **kwargs) -> "SyncStrea def __call__( self, *streams: "SyncStream", label: str | None = None, **kwargs ) -> "SyncStream": + # check that inputs are stream instances and if it's source, instantiate it + verified_streams = [] + for stream in streams: + if not isinstance(stream, SyncStream): + raise TypeError( + f"Expected SyncStream, got {type(stream).__name__} for stream {stream}" + ) + if isinstance(stream, Source): + # if the stream is a Source, instantiate it + stream = stream() + verified_streams.append(stream) + # Special handling of Source: trigger call on source if passed as stream normalized_streams = [ - stream() if isinstance(stream, Source) else stream for stream in streams + stream() if isinstance(stream, Source) else stream + for stream in verified_streams ] pre_processed_streams = self.pre_forward_hook(*normalized_streams, **kwargs) diff --git a/src/orcapod/core/pod.py b/src/orcapod/core/pod.py index d64bafa..92d8568 100644 --- a/src/orcapod/core/pod.py +++ b/src/orcapod/core/pod.py @@ -8,7 +8,10 @@ ) from orcapod.types import Packet, Tag, TypeSpec, default_registry -from orcapod.types.typespec_utils import extract_function_typespecs +from orcapod.types.typespec_utils import ( + extract_function_typespecs, + check_typespec_compatibility, +) from orcapod.types.packets import PacketConverter from orcapod.hashing import ( @@ -221,6 +224,20 @@ def __init__( self.function_output_typespec, self.registry ) + def forward(self, *streams: SyncStream, **kwargs) -> SyncStream: + assert len(streams) == 1, ( + "Only one stream is supported in forward() of FunctionPod" + ) + stream = streams[0] + _, packet_typespec = stream.types(trigger_run=False) + if packet_typespec is not None and not check_typespec_compatibility( + packet_typespec, self.function_input_typespec + ): + raise TypeError( + f"Input packet types {packet_typespec} is not compatible with the function's expected input types {self.function_input_typespec}" + ) + return super().forward(*streams, **kwargs) + def get_function_typespecs(self) -> tuple[TypeSpec, TypeSpec]: return self.function_input_typespec, self.function_output_typespec diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index fabb664..07d9eb4 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -169,7 +169,10 @@ def update_cached_values(self): self.kernel_hash = self.kernel_hasher.hash_to_hex( self.kernel, prefix_hasher_id=True ) - self.tag_keys, self.packet_keys = self.keys(trigger_run=False) + tag_keys, packet_keys = self.keys(trigger_run=False) + self.tag_keys = tuple(tag_keys) if tag_keys is not None else None + self.packet_keys = tuple(packet_keys) if packet_keys is not None else None + self.tag_typespec, self.packet_typespec = self.types(trigger_run=False) if self.tag_typespec is None or self.packet_typespec is None: raise ValueError( @@ -248,7 
+251,14 @@ def output_iterator_completion_hook(self) -> None: @property def lazy_df(self) -> pl.LazyFrame | None: - return self.output_store.get_all_records_as_polars(self.store_path) + lazydf = self.output_store.get_all_records_as_polars(self.store_path) + if lazydf is None: + return None + if self.tag_keys is None or self.packet_keys is None: + raise ValueError( + "CachedKernelWrapper has no tag keys or packet keys defined, and currently this is not supported" + ) + return lazydf.select(self.tag_keys + self.packet_keys) @property def df(self) -> pl.DataFrame | None: diff --git a/src/orcapod/pipeline/pipeline.py b/src/orcapod/pipeline/pipeline.py index 2a7d86e..1fb5236 100644 --- a/src/orcapod/pipeline/pipeline.py +++ b/src/orcapod/pipeline/pipeline.py @@ -53,7 +53,7 @@ def __init__( self.pipeline_store = pipeline_store self.results_store = results_store - self.labels_to_nodes = {} + self.nodes = {} self.auto_compile = auto_compile self._dirty = False self._ordered_nodes = [] # Track order of invocations @@ -167,7 +167,8 @@ def compile(self): nodes[0].label = label labels_to_nodes[label] = nodes[0] - self.labels_to_nodes = labels_to_nodes + # store as pipeline's nodes attribute + self.nodes = labels_to_nodes self._dirty = False return node_lut, edge_lut, proposed_labels, labels_to_nodes @@ -178,13 +179,28 @@ def __exit__(self, exc_type, exc_val, ext_tb): def __getattr__(self, item: str) -> Any: """Allow direct access to pipeline attributes""" - if item in self.labels_to_nodes: - return self.labels_to_nodes[item] + if item in self.nodes: + return self.nodes[item] raise AttributeError(f"Pipeline has no attribute '{item}'") def __dir__(self): # Include both regular attributes and dynamic ones - return list(super().__dir__()) + list(self.labels_to_nodes.keys()) + return list(super().__dir__()) + list(self.nodes.keys()) + + def rename(self, old_name: str, new_name: str) -> None: + """ + Rename a node in the pipeline. + This will update the label and the internal mapping. + """ + if old_name not in self.nodes: + raise KeyError(f"Node '{old_name}' does not exist in the pipeline.") + if new_name in self.nodes: + raise KeyError(f"Node '{new_name}' already exists in the pipeline.") + node = self.nodes[old_name] + del self.nodes[old_name] + node.label = new_name + self.nodes[new_name] = node + logger.info(f"Node '{old_name}' renamed to '{new_name}'") def run(self, full_sync: bool = False) -> None: """ diff --git a/src/orcapod/types/core.py b/src/orcapod/types/core.py index 62c100d..22491ae 100644 --- a/src/orcapod/types/core.py +++ b/src/orcapod/types/core.py @@ -15,7 +15,7 @@ # an (optional) string or a collection of (optional) string values # Note that TagValue can be nested, allowing for an arbitrary depth of nested lists -TagValue: TypeAlias = str | None | Collection["TagValue"] +TagValue: TypeAlias = int | str | None | Collection["TagValue"] # the top level tag is a mapping from string keys to values that can be a string or # an arbitrary depth of nested list of strings or None From 083134b050c2c443ad65a8100f4e90177c00634d Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 21:46:11 +0000 Subject: [PATCH 53/57] fix: use new store name --- src/orcapod/stores/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orcapod/stores/__init__.py b/src/orcapod/stores/__init__.py index 281874b..1114c11 100644 --- a/src/orcapod/stores/__init__.py +++ b/src/orcapod/stores/__init__.py @@ -1,5 +1,5 @@ from .types import DataStore, ArrowDataStore -from .arrow_data_stores import MockArrowDataStore, SimpleInMemoryDataStore +from .arrow_data_stores import MockArrowDataStore, SimpleParquetDataStore from .dict_data_stores import DirDataStore, NoOpDataStore from .safe_dir_data_store import SafeDirDataStore @@ -10,5 +10,5 @@ "SafeDirDataStore", "NoOpDataStore", "MockArrowDataStore", - "SimpleInMemoryDataStore", + "SimpleParquetDataStore", ] From 7e33bae162e2296ca1f2dff366860ec942780261 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 3 Jul 2025 21:46:24 +0000 Subject: [PATCH 54/57] test: update to use new package name --- tests/test_store/test_transfer_data_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_store/test_transfer_data_store.py b/tests/test_store/test_transfer_data_store.py index 21ed4c9..4721691 100644 --- a/tests/test_store/test_transfer_data_store.py +++ b/tests/test_store/test_transfer_data_store.py @@ -7,7 +7,7 @@ from orcapod.hashing.types import LegacyPacketHasher from orcapod.stores.dict_data_stores import DirDataStore, NoOpDataStore -from orcapod.stores.transfer_data_store import TransferDataStore +from orcapod.stores.dict_transfer_data_store import TransferDataStore class MockPacketHasher(LegacyPacketHasher): From 5641810382d8769ff925e3b7ba1ea8e156d4aaf5 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 3 Jul 2025 21:50:33 +0000 Subject: [PATCH 55/57] fix: wrong import --- src/orcapod/hashing/arrow_hashers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index a7b5a01..465b29b 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -4,11 +4,11 @@ import polars as pl import json from orcapod.hashing.types import SemanticTypeHasher, StringCacher -from orcapod.hashing import arrow_serialization_old +from orcapod.hashing import arrow_serialization from collections.abc import Callable SERIALIZATION_METHOD_LUT: dict[str, Callable[[pa.Table], bytes]] = { - "logical": arrow_serialization_old.serialize_table_logical, + "logical": arrow_serialization.serialize_table_logical, } From c66920c13613cb347d661b820c89041aaed45b91 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Thu, 3 Jul 2025 22:12:05 +0000 Subject: [PATCH 56/57] doc: handle typing corner cases --- src/orcapod/stores/arrow_data_stores.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/orcapod/stores/arrow_data_stores.py b/src/orcapod/stores/arrow_data_stores.py index 93a2400..0a9a7e9 100644 --- a/src/orcapod/stores/arrow_data_stores.py +++ b/src/orcapod/stores/arrow_data_stores.py @@ -1841,6 +1841,8 @@ def get_all_records( df = self.get_all_records_as_polars( source_name, source_id, add_entry_id_column=add_entry_id_column ) + if df is None: + return None return df.collect().to_arrow() def get_all_records_as_polars( @@ -1917,9 +1919,9 @@ def get_records_by_ids( elif isinstance(entry_ids, pa.Array): if len(entry_ids) == 0: return None - entry_ids_series = pl.from_arrow(pa.table({"entry_id": entry_ids}))[ - "entry_id" - ] + entry_ids_series: pl.Series = pl.from_arrow( + pa.table({"entry_id": entry_ids}) + )["entry_id"] # type: ignore else: raise TypeError( f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" @@ -1993,7 +1995,8 @@ def get_records_by_ids_as_polars( return None # Convert to Polars LazyFrame - return pl.from_arrow(arrow_result).lazy() + df = cast(pl.DataFrame, pl.from_arrow(arrow_result)) + return df.lazy() def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: """Check if a specific entry exists.""" From 608428f719c5cd513ba5b6171c14c67723023cd7 Mon Sep 17 00:00:00 2001 From: Brian Arnold Date: Sun, 6 Jul 2025 03:59:56 +0000 Subject: [PATCH 57/57] Add ListSource and DataFrameSource --- src/orcapod/core/sources.py | 316 +++++++++++++++++++++++++++++++++++- 1 file changed, 315 insertions(+), 1 deletion(-) diff --git a/src/orcapod/core/sources.py b/src/orcapod/core/sources.py index b1dca7d..3aeffe2 100644 --- a/src/orcapod/core/sources.py +++ b/src/orcapod/core/sources.py @@ -2,7 +2,7 @@ from os import PathLike from pathlib import Path from typing import Any, Literal - +import pandas as pd import polars as pl from orcapod.core.base import Source @@ -202,3 +202,317 @@ def keys( def computed_label(self) -> str | None: return self.stream.label + + +class DataFrameSource(Source): + """ + A stream source that sources data from a pandas DataFrame. + + For each row in the DataFrame, yields a tuple containing: + - A tag generated either by the provided tag_function or defaulting to the row index + - A packet containing values from specified columns as key-value pairs + + Parameters + ---------- + columns : list[str] + List of column names to include in the packet. These will serve as the keys + in the packet, with the corresponding row values as the packet values. + data : pd.DataFrame + The pandas DataFrame to source data from + tag_function : Callable[[pd.Series, int], Tag] | None, default=None + Optional function to generate a tag from a DataFrame row and its index. + The function receives the row as a pandas Series and the row index as arguments. + If None, uses the row index in a dict with key 'row_index' + tag_function_hash_mode : Literal["content", "signature", "name"], default="name" + How to hash the tag function for identity purposes + expected_tag_keys : Collection[str] | None, default=None + Expected tag keys for the stream + label : str | None, default=None + Optional label for the source + + Examples + -------- + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... 'file_path': ['/path/to/file1.txt', '/path/to/file2.txt'], + ... 
'metadata_path': ['/path/to/meta1.json', '/path/to/meta2.json'], + ... 'sample_id': ['sample_1', 'sample_2'] + ... }) + >>> # Use sample_id column for tags and include file paths in packets + >>> source = DataFrameSource( + ... columns=['file_path', 'metadata_path'], + ... data=df, + ... tag_function=lambda row, idx: {'sample_id': row['sample_id']} + ... ) + >>> # Use default row index tagging + >>> source = DataFrameSource(['file_path', 'metadata_path'], df) + """ + + @staticmethod + def default_tag_function(row: pd.Series, idx: int) -> Tag: + return {"row_index": idx} + + def __init__( + self, + columns: list[str], + data: pd.DataFrame, + tag_function: Callable[[pd.Series, int], Tag] | None = None, + label: str | None = None, + tag_function_hash_mode: Literal["content", "signature", "name"] = "name", + expected_tag_keys: Collection[str] | None = None, + **kwargs, + ) -> None: + super().__init__(label=label, **kwargs) + self.columns = columns + self.dataframe = data + + # Validate that all specified columns exist in the DataFrame + missing_columns = set(columns) - set(data.columns) + if missing_columns: + raise ValueError(f"Columns not found in DataFrame: {missing_columns}") + + if tag_function is None: + tag_function = self.__class__.default_tag_function + # If using default tag function and no explicit expected_tag_keys, set to default + if expected_tag_keys is None: + expected_tag_keys = ["row_index"] + + self.expected_tag_keys = expected_tag_keys + self.tag_function = tag_function + self.tag_function_hash_mode = tag_function_hash_mode + + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 0: + raise ValueError( + "DataFrameSource does not support forwarding streams. " + "It generates its own stream from the DataFrame." + ) + + def generator() -> Iterator[tuple[Tag, Packet]]: + for idx, row in self.dataframe.iterrows(): + tag = self.tag_function(row, idx) + packet = {col: row[col] for col in self.columns} + yield tag, packet + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + return f"DataFrameSource(cols={self.columns}, rows={len(self.dataframe)})" + + def identity_structure(self, *streams: SyncStream) -> Any: + hash_function_kwargs = {} + if self.tag_function_hash_mode == "content": + # if using content hash, exclude few + hash_function_kwargs = { + "include_name": False, + "include_module": False, + "include_declaration": False, + } + + tag_function_hash = hash_function( + self.tag_function, + function_hash_mode=self.tag_function_hash_mode, + hash_kwargs=hash_function_kwargs, + ) + + # Convert DataFrame to hashable representation + df_subset = self.dataframe[self.columns] + df_content = df_subset.to_dict('records') + df_hashable = tuple(tuple(sorted(record.items())) for record in df_content) + + return ( + self.__class__.__name__, + tuple(self.columns), + df_hashable, + tag_function_hash, + ) + tuple(streams) + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the stream. The keys are the names of the packets + in the stream. The keys are used to identify the packets in the stream. + If expected_keys are provided, they will be used instead of the default keys. + """ + if len(streams) != 0: + raise ValueError( + "DataFrameSource does not support forwarding streams. " + "It generates its own stream from the DataFrame." 
+ ) + + if self.expected_tag_keys is not None: + return tuple(self.expected_tag_keys), tuple(self.columns) + return super().keys(trigger_run=trigger_run) + + def claims_unique_tags( + self, *streams: "SyncStream", trigger_run: bool = True + ) -> bool | None: + if len(streams) != 0: + raise ValueError( + "DataFrameSource does not support forwarding streams. " + "It generates its own stream from the DataFrame." + ) + # Claim uniqueness only if the default tag function is used + if self.tag_function == self.__class__.default_tag_function: + return True + # Otherwise, delegate to the base class + return super().claims_unique_tags(trigger_run=trigger_run) + + +class ListSource(Source): + """ + A stream source that sources data from a list of elements. + + For each element in the list, yields a tuple containing: + - A tag generated either by the provided tag_function or defaulting to the element index + - A packet containing the element under the provided name key + + Parameters + ---------- + name : str + The key name under which each list element will be stored in the packet + data : list[Any] + The list of elements to source data from + tag_function : Callable[[Any, int], Tag] | None, default=None + Optional function to generate a tag from a list element and its index. + The function receives the element and the index as arguments. + If None, uses the element index in a dict with key 'element_index' + tag_function_hash_mode : Literal["content", "signature", "name"], default="name" + How to hash the tag function for identity purposes + expected_tag_keys : Collection[str] | None, default=None + Expected tag keys for the stream + label : str | None, default=None + Optional label for the source + + Examples + -------- + >>> # Simple list of file names + >>> file_list = ['/path/to/file1.txt', '/path/to/file2.txt', '/path/to/file3.txt'] + >>> source = ListSource('file_path', file_list) + >>> + >>> # Custom tag function using filename stems + >>> from pathlib import Path + >>> source = ListSource( + ... 'file_path', + ... file_list, + ... tag_function=lambda elem, idx: {'file_name': Path(elem).stem} + ... ) + >>> + >>> # List of sample IDs + >>> samples = ['sample_001', 'sample_002', 'sample_003'] + >>> source = ListSource( + ... 'sample_id', + ... samples, + ... tag_function=lambda elem, idx: {'sample': elem} + ... ) + """ + + @staticmethod + def default_tag_function(element: Any, idx: int) -> Tag: + return {"element_index": idx} + + def __init__( + self, + name: str, + data: list[Any], + tag_function: Callable[[Any, int], Tag] | None = None, + label: str | None = None, + tag_function_hash_mode: Literal["content", "signature", "name"] = "name", + expected_tag_keys: Collection[str] | None = None, + **kwargs, + ) -> None: + super().__init__(label=label, **kwargs) + self.name = name + self.elements = list(data) # Create a copy to avoid external modifications + + if tag_function is None: + tag_function = self.__class__.default_tag_function + # If using default tag function and no explicit expected_tag_keys, set to default + if expected_tag_keys is None: + expected_tag_keys = ["element_index"] + + self.expected_tag_keys = expected_tag_keys + self.tag_function = tag_function + self.tag_function_hash_mode = tag_function_hash_mode + + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 0: + raise ValueError( + "ListSource does not support forwarding streams. " + "It generates its own stream from the list elements." 
+ ) + + def generator() -> Iterator[tuple[Tag, Packet]]: + for idx, element in enumerate(self.elements): + tag = self.tag_function(element, idx) + packet = {self.name: element} + yield tag, packet + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + return f"ListSource({self.name}, {len(self.elements)} elements)" + + def identity_structure(self, *streams: SyncStream) -> Any: + hash_function_kwargs = {} + if self.tag_function_hash_mode == "content": + # if using content hash, exclude few + hash_function_kwargs = { + "include_name": False, + "include_module": False, + "include_declaration": False, + } + + tag_function_hash = hash_function( + self.tag_function, + function_hash_mode=self.tag_function_hash_mode, + hash_kwargs=hash_function_kwargs, + ) + + # Convert list to hashable representation + # Handle potentially unhashable elements by converting to string + try: + elements_hashable = tuple(self.elements) + except TypeError: + # If elements are not hashable, convert to string representation + elements_hashable = tuple(str(elem) for elem in self.elements) + + return ( + self.__class__.__name__, + self.name, + elements_hashable, + tag_function_hash, + ) + tuple(streams) + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the stream. The keys are the names of the packets + in the stream. The keys are used to identify the packets in the stream. + If expected_keys are provided, they will be used instead of the default keys. + """ + if len(streams) != 0: + raise ValueError( + "ListSource does not support forwarding streams. " + "It generates its own stream from the list elements." + ) + + if self.expected_tag_keys is not None: + return tuple(self.expected_tag_keys), (self.name,) + return super().keys(trigger_run=trigger_run) + + def claims_unique_tags( + self, *streams: "SyncStream", trigger_run: bool = True + ) -> bool | None: + if len(streams) != 0: + raise ValueError( + "ListSource does not support forwarding streams. " + "It generates its own stream from the list elements." + ) + # Claim uniqueness only if the default tag function is used + if self.tag_function == self.__class__.default_tag_function: + return True + # Otherwise, delegate to the base class + return super().claims_unique_tags(trigger_run=trigger_run)
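
A minimal usage sketch for the ListSource and DataFrameSource added in this final patch. It relies only on the constructor signatures and default tag functions shown in the diff above; the import path orcapod.core.sources, the assumption that calling a Source with no input streams yields an iterable SyncStream, and the (tag, packet) iteration idiom are inferred from the surrounding code rather than a documented API.

# Sketch, not part of the patch series. Assumes ListSource and DataFrameSource are
# importable from orcapod.core.sources and that a Source called with no input
# streams returns a SyncStream that iterates over (tag, packet) pairs.
import pandas as pd

from orcapod.core.sources import DataFrameSource, ListSource

# ListSource: each list element becomes one packet under the given key; the
# default tag function tags elements by position as {'element_index': i}.
files = ["/data/a.txt", "/data/b.txt", "/data/c.txt"]
file_source = ListSource("file_path", files)
for tag, packet in file_source():
    print(tag, packet)  # e.g. {'element_index': 0} {'file_path': '/data/a.txt'}

# DataFrameSource: the selected columns become packet fields; a custom tag
# function can pull the tag from another column instead of the row index.
df = pd.DataFrame(
    {
        "sample_id": ["s1", "s2"],
        "file_path": ["/data/a.txt", "/data/b.txt"],
        "metadata_path": ["/data/a.json", "/data/b.json"],
    }
)
sample_source = DataFrameSource(
    columns=["file_path", "metadata_path"],
    data=df,
    tag_function=lambda row, idx: {"sample_id": row["sample_id"]},
)
for tag, packet in sample_source():
    print(tag, packet)  # e.g. {'sample_id': 's1'} {'file_path': '/data/a.txt', 'metadata_path': '/data/a.json'}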