21 changes: 21 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,21 @@
repos:
  - repo: https://github.com/tsvikas/sync-with-uv
    rev: v0.4.0 # replace with the latest version
    hooks:
      - id: sync-with-uv
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.14.4
    hooks:
      - id: ruff-format
        types_or: [ python, pyi ]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: trailing-whitespace
        types_or: [ python, pyi ]
      - id: end-of-file-fixer
        types_or: [ python, pyi ]
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
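
Once this config lands, each contributor enables the hooks with the standard pre-commit CLI (not part of this diff):

    pre-commit install            # register the git hook in .git/hooks
    pre-commit run --all-files    # one-off run across the whole repo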
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -54,6 +54,8 @@ dev = [
     "ipywidgets>=8.1.7",
     "jsonschema>=4.25.0",
     "minio>=7.2.16",
+    "pre-commit>=4.4.0",
+    "pre-commit-hooks>=6.0.0",
     "pyarrow-stubs>=20.0.0.20250716",
     "pygraphviz>=1.14",
     "pyiceberg>=0.9.1",
@@ -62,7 +64,7 @@ dev = [
     "pytest-cov>=6.1.1",
     "ray[default]==2.48.0",
     "redis>=6.2.0",
-    "ruff>=0.11.11",
+    "ruff>=0.14.4",
     "sphinx>=8.2.3",
     "tqdm>=4.67.1",
 ]
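
Note that the ruff pin now matches the rev of ruff-pre-commit above (0.14.4 in both); the sync-with-uv hook exists to keep hook revs aligned with the locked versions. Assuming dev is a uv dependency group, a typical local flow would be:

    uv sync                             # installs the dev group, including pre-commit
    uv run pre-commit run --all-files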
1 change: 0 additions & 1 deletion src/orcapod/__init__.py
@@ -9,7 +9,6 @@
from .pipeline import Pipeline


-
no_tracking = DEFAULT_TRACKER_MANAGER.no_tracking

__all__ = [
6 changes: 2 additions & 4 deletions src/orcapod/core/pods.py
@@ -254,7 +254,6 @@ def function_pod(
    """

    def decorator(func: Callable) -> CallableWithPod:
-
        if func.__name__ == "<lambda>":
            raise ValueError("Lambda functions cannot be used with function_pod")

@@ -276,6 +275,7 @@ def wrapper(*args, **kwargs):
        )
        setattr(wrapper, "pod", pod)
        return cast(CallableWithPod, wrapper)
+
    return decorator


@@ -496,9 +496,7 @@ async def async_call(
        if execution_engine is not None:
            # use the provided execution engine to run the function
            values = await execution_engine.submit_async(
-                self.function,
-                fn_kwargs=input_dict,
-                **(execution_engine_opts or {})
+                self.function, fn_kwargs=input_dict, **(execution_engine_opts or {})
            )
        else:
            values = self.function(**input_dict)
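
For context, the function_pod changes above are pure formatting; the underlying pattern is a decorator that rejects lambdas and attaches a companion object to the wrapper. A minimal self-contained sketch of that pattern (names here are illustrative, not orcapod's API):

    from functools import wraps
    from typing import Any, Callable


    def with_companion(func: Callable) -> Callable:
        # Lambdas have no usable __name__, so reject them up front,
        # just as function_pod does.
        if func.__name__ == "<lambda>":
            raise ValueError("Lambda functions cannot be decorated")

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        # Expose a companion object on the wrapper, mirroring
        # setattr(wrapper, "pod", pod) above.
        setattr(wrapper, "companion", {"name": func.__name__})
        return wrapper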
4 changes: 3 additions & 1 deletion src/orcapod/core/streams/base.py
@@ -477,7 +477,9 @@ def flow(
    def _repr_html_(self) -> str:
        df = self.as_polars_df()
        # reorder columns
-        new_column_order = [c for c in df.columns if c in self.tag_keys()] + [c for c in df.columns if c not in self.tag_keys()]
+        new_column_order = [c for c in df.columns if c in self.tag_keys()] + [
+            c for c in df.columns if c not in self.tag_keys()
+        ]
        df = df[new_column_order]
        tag_map = {t: f"*{t}" for t in self.tag_keys()}
        # TODO: construct repr html better
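
The reflowed comprehension above implements a stable partition: tag columns first, everything else after, with relative order preserved. A quick standalone check of the idiom with polars (toy column names, not the real schema):

    import polars as pl

    df = pl.DataFrame({"value": [1, 2], "tag": ["a", "b"]})
    tag_keys = ["tag"]  # hypothetical tag columns

    new_order = [c for c in df.columns if c in tag_keys] + [
        c for c in df.columns if c not in tag_keys
    ]
    assert df.select(new_order).columns == ["tag", "value"]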
3 changes: 2 additions & 1 deletion src/orcapod/core/streams/lazy_pod_stream.py
@@ -225,7 +225,8 @@ def as_table(
        # TODO: verify that order will be preserved
        for tag, packet in self.iter_packets(
            execution_engine=execution_engine or self.execution_engine,
-            execution_engine_opts=execution_engine_opts or self._execution_engine_opts,
+            execution_engine_opts=execution_engine_opts
+            or self._execution_engine_opts,
        ):
            content_hashes.append(packet.content_hash().to_string())
        self._cached_content_hash_column = pa.array(
18 changes: 9 additions & 9 deletions src/orcapod/core/streams/pod_node_stream.py
@@ -67,7 +67,8 @@ async def run_async(
        This is typically called before iterating over the packets.
        """
        if self._cached_output_packets is None:
-            cached_results, missing = self._identify_existing_and_missing_entries(*args,
+            cached_results, missing = self._identify_existing_and_missing_entries(
+                *args,
                execution_engine=execution_engine,
                execution_engine_opts=execution_engine_opts,
                **kwargs,
@@ -90,6 +91,7 @@ async def run_async(
                pending_calls.append(pending)

            import asyncio
+
            completed_calls = await asyncio.gather(*pending_calls)
            for result in completed_calls:
                cached_results.append(result)
@@ -99,12 +101,14 @@ async def run_async(
            self._set_modified_time()
            self.pod_node.flush()

-    def _identify_existing_and_missing_entries(self,
-            *args: Any,
+    def _identify_existing_and_missing_entries(
+        self,
+        *args: Any,
        execution_engine: cp.ExecutionEngine | None = None,
        execution_engine_opts: dict[str, Any] | None = None,
-        **kwargs: Any) -> tuple[list[tuple[cp.Tag, cp.Packet|None]], pa.Table | None]:
-        cached_results: list[tuple[cp.Tag, cp.Packet|None]] = []
+        **kwargs: Any,
+    ) -> tuple[list[tuple[cp.Tag, cp.Packet | None]], pa.Table | None]:
+        cached_results: list[tuple[cp.Tag, cp.Packet | None]] = []

        # identify all entries in the input stream for which we still have not computed packets
        if len(args) > 0 or len(kwargs) > 0:
@@ -177,8 +181,6 @@ def _identify_existing_and_missing_entries(self,
            for tag, packet in existing_stream.iter_packets():
                cached_results.append((tag, packet))

-
-
        return cached_results, missing

    def run(
@@ -230,7 +232,6 @@ def run(
                )
                cached_results.append((tag, output_packet))

-
            # reset the cache and set new results
            self.clear_cache()
            self._cached_output_packets = cached_results
@@ -276,7 +277,6 @@ def iter_packets(
        self._cached_output_packets = cached_results
        self._set_modified_time()

-
    def keys(
        self, include_system_tags: bool = False
    ) -> tuple[tuple[str, ...], tuple[str, ...]]:
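
run_async above follows the standard fan-out/fan-in shape: build a list of pending coroutines, then await asyncio.gather once. A stripped-down sketch of that shape (stand-in coroutine, not the real pod call):

    import asyncio


    async def compute(tag: str) -> tuple[str, str]:
        await asyncio.sleep(0)  # stand-in for the actual pod invocation
        return tag, f"packet-for-{tag}"


    async def run_all(missing: list[str]) -> list[tuple[str, str]]:
        pending = [compute(tag) for tag in missing]
        # gather preserves input order, which matters when results are
        # appended to cached_results as run_async does above
        return await asyncio.gather(*pending)


    print(asyncio.run(run_all(["a", "b"])))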
4 changes: 2 additions & 2 deletions src/orcapod/hashing/string_cachers.py
@@ -316,7 +316,7 @@ def _init_database(self) -> None:
                )
            """)
            conn.execute("""
-                CREATE INDEX IF NOT EXISTS idx_last_accessed 
+                CREATE INDEX IF NOT EXISTS idx_last_accessed
                ON cache_entries(last_accessed)
            """)
            conn.commit()

(Here, and in the other visually identical - / + pairs below, the only change is trailing whitespace stripped by the new trailing-whitespace hook.)
@@ -330,7 +330,7 @@ def _load_from_database(self) -> None:
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.execute("""
-                    SELECT key, value FROM cache_entries 
+                    SELECT key, value FROM cache_entries
                    ORDER BY last_accessed DESC
                """)

7 changes: 4 additions & 3 deletions src/orcapod/pipeline/graph.py
@@ -45,8 +45,6 @@ def run_in_thread():
        return asyncio.run(async_func(*args, **kwargs))


-
-
class GraphNode:
    def __init__(self, label: str, id: int, kernel_type: str):
        self.label = label
@@ -230,7 +228,10 @@ def run(
        may implement more efficient graph traversal algorithms.
        """
        import networkx as nx
-        if run_async is True and (execution_engine is None or not execution_engine.supports_async):
+
+        if run_async is True and (
+            execution_engine is None or not execution_engine.supports_async
+        ):
            raise ValueError(
                "Cannot run asynchronously with an execution engine that does not support async."
            )
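
The run_in_thread helper at the top of this file wraps asyncio.run; the usual reason for that pattern is that asyncio.run raises when called inside an already-running event loop. A minimal sketch of the workaround (simplified; the real helper's signature may differ):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor


    async def async_func() -> str:
        return "done"


    def run_in_thread() -> str:
        # A worker thread has no running event loop, so asyncio.run is
        # safe there even if the caller is inside one.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(async_func())).result()


    print(run_in_thread())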
5 changes: 1 addition & 4 deletions src/orcapod/pipeline/nodes.py
@@ -270,7 +270,7 @@ def __init__(
    def execution_engine_opts(self) -> dict[str, Any]:
        return self._execution_engine_opts.copy()

-    @execution_engine_opts.setter 
+    @execution_engine_opts.setter
    def execution_engine_opts(self, opts: dict[str, Any]) -> None:
        self._execution_engine_opts = opts

@@ -322,7 +322,6 @@ def call(
        if execution_engine_opts is not None:
            combined_execution_engine_opts.update(execution_engine_opts)

-
        tag, output_packet = super().call(
            tag,
            packet,
@@ -362,12 +361,10 @@ async def async_call(
        if record_id is None:
            record_id = self.get_record_id(packet, execution_engine_hash)

-
        combined_execution_engine_opts = self.execution_engine_opts
        if execution_engine_opts is not None:
            combined_execution_engine_opts.update(execution_engine_opts)

-
        tag, output_packet = await super().async_call(
            tag,
            packet,
1 change: 1 addition & 0 deletions src/orcapod/protocols/core_protocols/base.py
@@ -41,6 +41,7 @@ class ExecutionEngine(Protocol):
    "local", "threadpool", "processpool", or "ray" and is used for logging
    and diagnostics.
    """
+
    @property
    def supports_async(self) -> bool:
        """Indicate whether this engine supports async execution."""
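
Since ExecutionEngine is a typing.Protocol, engines satisfy it structurally; the new blank line above is purely formatting. A minimal conforming implementation, for illustration only:

    from typing import Protocol


    class ExecutionEngine(Protocol):
        @property
        def supports_async(self) -> bool:
            """Indicate whether this engine supports async execution."""
            ...


    class LocalEngine:
        # No inheritance needed: defining the property is enough
        # to satisfy the protocol for a static type checker.
        @property
        def supports_async(self) -> bool:
            return False


    engine: ExecutionEngine = LocalEngine()
    print(engine.supports_async)  # False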
4 changes: 2 additions & 2 deletions tests/test_hashing/test_sqlite_cacher.py
@@ -47,7 +47,7 @@ def test_database_initialization():
    # Check that table exists with correct schema
    with sqlite3.connect(db_file) as conn:
        cursor = conn.execute("""
-            SELECT sql FROM sqlite_master 
+            SELECT sql FROM sqlite_master
            WHERE type='table' AND name='cache_entries'
        """)
        schema = cursor.fetchone()[0]
@@ -58,7 +58,7 @@ def test_database_initialization():

        # Check that index exists
        cursor = conn.execute("""
-            SELECT name FROM sqlite_master 
+            SELECT name FROM sqlite_master
            WHERE type='index' AND name='idx_last_accessed'
        """)
        assert cursor.fetchone() is not None
4 changes: 2 additions & 2 deletions tests/test_hashing/test_string_cacher/test_sqlite_cacher.py
@@ -47,7 +47,7 @@ def test_database_initialization():
    # Check that table exists with correct schema
    with sqlite3.connect(db_file) as conn:
        cursor = conn.execute("""
-            SELECT sql FROM sqlite_master 
+            SELECT sql FROM sqlite_master
            WHERE type='table' AND name='cache_entries'
        """)
        schema = cursor.fetchone()[0]
@@ -58,7 +58,7 @@ def test_database_initialization():

        # Check that index exists
        cursor = conn.execute("""
-            SELECT name FROM sqlite_master 
+            SELECT name FROM sqlite_master
            WHERE type='index' AND name='idx_last_accessed'
        """)
        assert cursor.fetchone() is not None