Skip to content

Commit 9bbc2a0

Browse files
authored
ENH: Make attributes mandatory on PNode and PProvisionalNode (#787)
1 parent 9bdf2c3 commit 9bbc2a0

8 files changed

Lines changed: 43 additions & 36 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
77

88
## Unreleased
99

10+
- {pull}`787` makes the `attributes` field mandatory on `PNode` and
11+
`PProvisionalNode`, and preserves existing node attributes when loading entries from
12+
the data catalog.
1013
- {pull}`744` Removed the direct dependency on attrs and migrated internal models to
1114
dataclasses.
1215
- {pull}`766` moves runtime profiling persistence from SQLite to a JSON snapshot plus

src/_pytask/collect.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@
3333
from _pytask.node_protocols import PPathNode
3434
from _pytask.node_protocols import PProvisionalNode
3535
from _pytask.node_protocols import PTask
36-
from _pytask.node_protocols import warn_about_upcoming_attributes_field_on_nodes
3736
from _pytask.nodes import DirectoryNode
3837
from _pytask.nodes import PathNode
3938
from _pytask.nodes import PythonNode
@@ -419,9 +418,6 @@ def pytask_collect_node( # noqa: C901, PLR0912
419418
"""
420419
node = node_info.value
421420

422-
if isinstance(node, (PNode, PProvisionalNode)) and not hasattr(node, "attributes"):
423-
warn_about_upcoming_attributes_field_on_nodes()
424-
425421
if isinstance(node, DirectoryNode):
426422
if node.root_dir is None:
427423
node.root_dir = path

src/_pytask/data_catalog.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
from _pytask.models import NodeInfo
2222
from _pytask.node_protocols import PNode
2323
from _pytask.node_protocols import PProvisionalNode
24-
from _pytask.node_protocols import warn_about_upcoming_attributes_field_on_nodes
2524
from _pytask.nodes import PickleNode
2625
from _pytask.pluginmanager import storage
2726
from _pytask.session import Session
@@ -102,10 +101,7 @@ def __post_init__(self) -> None:
102101
# Initialize the data catalog with persisted nodes from previous runs.
103102
for path in self.path.glob("*-node.pkl"):
104103
node = pickle.loads(path.read_bytes()) # noqa: S301
105-
if not hasattr(node, "attributes"):
106-
warn_about_upcoming_attributes_field_on_nodes()
107-
else:
108-
node.attributes = {DATA_CATALOG_NAME_FIELD: self.name}
104+
node.attributes[DATA_CATALOG_NAME_FIELD] = self.name
109105
self._entries[node.name] = node
110106

111107
def __getitem__(self, name: str) -> PNode | PProvisionalNode:
@@ -150,7 +146,4 @@ def add(self, name: str, node: PNode | PProvisionalNode | Any = None) -> None:
150146
self._entries[name] = collected_node
151147

152148
node = self._entries[name]
153-
if hasattr(node, "attributes"):
154-
node.attributes[DATA_CATALOG_NAME_FIELD] = self.name # ty: ignore[invalid-assignment]
155-
else:
156-
warn_about_upcoming_attributes_field_on_nodes()
149+
node.attributes[DATA_CATALOG_NAME_FIELD] = self.name

src/_pytask/node_protocols.py

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import warnings
43
from typing import TYPE_CHECKING
54
from typing import Any
65
from typing import Protocol
@@ -22,6 +21,7 @@ class PNode(Protocol):
2221
"""Protocol for nodes."""
2322

2423
name: str
24+
attributes: dict[Any, Any]
2525

2626
@property
2727
def signature(self) -> str:
@@ -117,6 +117,7 @@ class PProvisionalNode(Protocol):
117117
"""
118118

119119
name: str
120+
attributes: dict[Any, Any]
120121

121122
@property
122123
def signature(self) -> str:
@@ -139,14 +140,3 @@ def load(self, is_product: bool = False) -> Any: # pragma: no cover
139140

140141
def collect(self) -> list[Any]:
141142
"""Collect the objects that are defined by the provisional nodes."""
142-
143-
144-
def warn_about_upcoming_attributes_field_on_nodes() -> None:
145-
warnings.warn(
146-
"PNode and PProvisionalNode will require an 'attributes' field starting "
147-
"with pytask v0.6.0. It is a dictionary with any type of key and values "
148-
"similar to PTask. See https://tinyurl.com/pytask-custom-nodes for more "
149-
"information about adjusting your custom nodes.",
150-
stacklevel=1,
151-
category=FutureWarning,
152-
)

tests/test_collect_command.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -468,15 +468,18 @@ def task_example(
468468
def test_node_protocol_for_custom_nodes(runner, tmp_path):
469469
source = """
470470
from typing import Annotated
471+
from typing import Any
471472
from pytask import Product
472473
from dataclasses import dataclass
474+
from dataclasses import field
473475
from pathlib import Path
474476
475477
@dataclass
476478
class CustomNode:
477479
name: str
478480
value: str
479481
signature: str = "id"
482+
attributes: dict[Any, Any] = field(default_factory=dict)
480483
481484
def state(self):
482485
return self.value
@@ -673,9 +676,11 @@ def task_example(
673676
def test_collect_custom_node_receives_default_name(runner, tmp_path):
674677
source = """
675678
from typing import Annotated
679+
from typing import Any
676680
677681
class CustomNode:
678682
name: str = ""
683+
attributes: dict[Any, Any] = {}
679684
680685
def state(self): return None
681686
def signature(self): return "signature"

tests/test_data_catalog.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import hashlib
4+
import pickle
35
import sys
46
import textwrap
57
from pathlib import Path
68

79
import pytest
810

11+
from _pytask.data_catalog_utils import DATA_CATALOG_NAME_FIELD
912
from pytask import DataCatalog
1013
from pytask import ExitCode
1114
from pytask import PathNode
@@ -198,6 +201,25 @@ def test_adding_a_python_node():
198201
assert isinstance(data_catalog["node"], PythonNode)
199202

200203

204+
def test_reloading_data_catalog_preserves_node_attributes(tmp_path):
205+
data_catalog = DataCatalog(_instance_path=tmp_path)
206+
_ = data_catalog["node"]
207+
assert data_catalog.path is not None
208+
209+
filename = hashlib.sha256(b"node").hexdigest()
210+
path_to_node = data_catalog.path / f"{filename}-node.pkl"
211+
212+
node = pickle.loads(path_to_node.read_bytes()) # noqa: S301
213+
node.attributes["custom"] = "value"
214+
path_to_node.write_bytes(pickle.dumps(node))
215+
216+
reloaded_data_catalog = DataCatalog(_instance_path=tmp_path)
217+
reloaded_node = reloaded_data_catalog["node"]
218+
219+
assert reloaded_node.attributes["custom"] == "value"
220+
assert reloaded_node.attributes[DATA_CATALOG_NAME_FIELD] == "default"
221+
222+
201223
def test_use_data_catalog_with_provisional_node(runner, tmp_path):
202224
source = """
203225
from pathlib import Path

tests/test_execute.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -421,13 +421,15 @@ def test_custom_node_as_product(runner, tmp_path, product_def, return_def):
421421
from typing import Any
422422
from typing import Annotated
423423
from dataclasses import dataclass
424+
from dataclasses import field
424425
from pytask import Product
425426
426427
@dataclass
427428
class PickleNode:
428429
path: Path
429430
name: str = ""
430431
signature: str = "id"
432+
attributes: dict[Any, Any] = field(default_factory=dict)
431433
432434
def state(self) -> str | None:
433435
if self.path.exists():
@@ -751,13 +753,15 @@ def test_errors_during_loading_nodes_have_info(runner, tmp_path):
751753
from pathlib import Path
752754
from typing import Any
753755
from dataclasses import dataclass
756+
from dataclasses import field
754757
import pickle
755758
756759
@dataclass
757760
class PickleNode:
758761
name: str
759762
path: Path
760763
signature: str = "id"
764+
attributes: dict[Any, Any] = field(default_factory=dict)
761765
762766
def state(self) -> str | None:
763767
if self.path.exists():

tests/test_node_protocols.py

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -92,12 +92,10 @@ def task_example(
9292
assert tmp_path.joinpath("out.txt").read_text() == "text"
9393

9494

95-
def test_node_protocol_for_custom_nodes_adding_attributes(runner, tmp_path):
95+
def test_node_protocol_for_custom_nodes_requires_attributes(runner, tmp_path):
9696
source = """
9797
from typing import Annotated
98-
from pytask import Product
9998
from dataclasses import dataclass
100-
from pathlib import Path
10199
102100
@dataclass
103101
class CustomNode:
@@ -114,15 +112,11 @@ def load(self, is_product):
114112
def save(self, value):
115113
self.value = value
116114
117-
def task_example(
118-
data = CustomNode("custom", "text"),
119-
out: Annotated[Path, Product] = Path("out.txt"),
120-
) -> None:
121-
out.write_text(data)
115+
def task_example() -> Annotated[str, CustomNode("custom", "text")]:
116+
return "text"
122117
"""
123118
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
124119

125120
result = runner.invoke(cli, [tmp_path.as_posix()])
126-
assert result.exit_code == ExitCode.OK
127-
assert tmp_path.joinpath("out.txt").read_text() == "text"
128-
assert "FutureWarning" in result.output
121+
assert result.exit_code == ExitCode.COLLECTION_FAILED
122+
assert "does not follow the 'pytask.PNode' protocol" in result.output

0 commit comments

Comments
 (0)