Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flowrep/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _check_inputs_with_defaults_subset_of_inputs(self) -> Self:
def validate_internal_data_completeness(self):
return self

def __call__(self, *args, **kwargs):
raise NotImplementedError(
f"{self.__class__.__name__} is not a callable recipe type"
)


class RestrictedParamKind(StrEnum):
"""
Expand Down
5 changes: 5 additions & 0 deletions flowrep/models/nodes/atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Literal

import pydantic
from pyiron_snippets import retrieve

from flowrep.models import base_models

Expand Down Expand Up @@ -72,3 +73,7 @@ def check_outputs_when_not_unpacking(self):
f"unpack_mode={self.unpack_mode.value}"
)
return self

def __call__(self, *args, **kwargs):
func = retrieve.import_from_string(self.reference.info.fully_qualified_name)
return func(*args, **kwargs)
10 changes: 10 additions & 0 deletions flowrep/models/nodes/workflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Literal

import pydantic
from pyiron_snippets import retrieve

from flowrep.models import base_models, edge_models, subgraph_validation

Expand Down Expand Up @@ -83,3 +84,12 @@ def validate_internal_data_completeness(self):
self.nodes, list(self.input_edges) + list(self.edges)
)
return self

def __call__(self, *args, **kwargs):
if self.reference is None:
raise ValueError(
f"{self.__class__.__name__} recipes are only callable when they are "
f"attached to an underlying python definiton in their reference field."
)
func = retrieve.import_from_string(self.reference.info.fully_qualified_name)
return func(*args, **kwargs)
16 changes: 8 additions & 8 deletions flowrep/models/parsers/atomic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,13 @@ def get_labeled_recipe(
scope: object_scope.ScopeProxy,
info_factory: versions.VersionInfoFactory,
) -> helper_models.LabeledNode:
child_call = cast(
FunctionType, object_scope.resolve_symbol_to_object(ast_call.func, scope)
)
child_call = object_scope.resolve_symbol_to_object(ast_call.func, scope)
# Since it is the .func attribute of an ast.Call,
# the retrieved object had better be a function
if hasattr(child_call, "flowrep_recipe"):
child_recipe = child_call.flowrep_recipe
function_call = cast(FunctionType, child_call)
label_prefix = function_call.__name__
if hasattr(function_call, "flowrep_recipe"):
child_recipe = function_call.flowrep_recipe
if hasattr(child_recipe, "reference") and isinstance(
child_recipe.reference.info, versions.VersionInfo
):
Expand All @@ -278,11 +278,11 @@ def get_labeled_recipe(
)
else:
child_recipe = parse_atomic(
child_call,
function_call,
version_scraping=info_factory.version_scraping,
forbid_main=info_factory.forbid_main,
forbid_locals=info_factory.forbid_locals,
require_version=info_factory.require_version,
)
child_name = label_helpers.unique_suffix(child_call.__name__, existing_names)
return helper_models.LabeledNode(label=child_name, node=child_recipe)
label = label_helpers.unique_suffix(label_prefix, existing_names)
return helper_models.LabeledNode(label=label, node=child_recipe)
25 changes: 25 additions & 0 deletions tests/unit/models/nodes/test_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,30 @@ def test_no_inputs_with_inputs_with_defaults_rejected(self):
)


class TestAtomicNodeCall(unittest.TestCase):
def test_call_on_bad_reference(self):
recipe = makers.make_atomic(inputs=["x", "y"])
with self.assertRaises(
ModuleNotFoundError,
msg="Neither the recipe nor the call validate that the underlying python "
"funcion reference is actually there -- so attempting to run it should "
"only fail at the point we actually try to import our made-up reference",
):
recipe(x=1, y=2)

def test_call(self):
"""
Calling atomic recipes should import and execute their underlying function
"""
recipe = atomic_model.AtomicNode(
reference=makers.make_reference("builtins", "int"),
inputs=["x_str"],
outputs=["x_int"],
)
result = recipe("1")
self.assertIsInstance(result, int)
self.assertEqual(result, 1)


if __name__ == "__main__":
unittest.main()
17 changes: 17 additions & 0 deletions tests/unit/models/nodes/test_for_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ def test_valid_for_node_with_both_nested_and_zipped(self):
self.assertEqual(for_node.nested_ports, ["a"])
self.assertEqual(for_node.zipped_ports, ["b", "c"])

def test_call_raises(self):
recipe = for_model.ForNode(
inputs=["x"],
outputs=[],
body_node=makers.make_labeled_atomic(
"body",
inputs=["item"],
outputs=["result"],
inputs_with_defaults=["item"],
),
input_edges={},
output_edges={},
nested_ports=["item"],
)
with self.assertRaises(NotImplementedError):
recipe(42)


class TestForNodeLoopPortValidation(unittest.TestCase):
def test_no_iteration_ports_rejected(self):
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/models/nodes/test_if_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _make_valid_if_node(n_cases=1, with_else=True):
)


class TestIfNodeBasicConstruction(unittest.TestCase):
class TestIfNodeBasic(unittest.TestCase):
def test_schema_generation(self):
"""model_json_schema() fails if forward refs aren't resolved."""
if_model.IfNode.model_json_schema()
Expand Down Expand Up @@ -100,6 +100,11 @@ def test_type_field_immutable(self):
node.type = base_models.RecipeElementType.WORKFLOW
self.assertIn("frozen", str(ctx.exception).lower())

def test_call_raises(self):
recipe = _make_valid_if_node()
with self.assertRaises(NotImplementedError):
recipe(42)


class TestIfNodeCasesValidation(unittest.TestCase):
def test_empty_cases_rejected(self):
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/models/nodes/test_try_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _make_valid_try_node(n_exception_cases=1):
)


class TestTryNodeBasicConstruction(unittest.TestCase):
class TestTryNodeBasic(unittest.TestCase):
def test_schema_generation(self):
"""model_json_schema() fails if forward refs aren't resolved."""
try_model.TryNode.model_json_schema()
Expand All @@ -86,6 +86,11 @@ def test_valid_multiple_exception_cases(self):
node = _make_valid_try_node(n_exception_cases=3)
self.assertEqual(len(node.exception_cases), 3)

def test_call_raises(self):
recipe = _make_valid_try_node()
with self.assertRaises(NotImplementedError):
recipe(42)


class TestTryNodeExceptionCasesValidation(unittest.TestCase):
def test_empty_exception_cases_rejected(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/models/nodes/test_while_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def test_valid_fully_wired(self):
self.assertEqual(len(wn.input_edges), 3)
self.assertEqual(len(wn.output_edges), 1)

def test_call_raises(self):
recipe = make_valid_while_node()
with self.assertRaises(NotImplementedError):
recipe(42)


class TestWhileNodeIOValidation(unittest.TestCase):
def test_duplicate_inputs_rejected(self):
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/models/nodes/test_workflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def _make_child(with_defaults: bool) -> atomic_model.AtomicNode:
)


def _some_wf(x, y):
return y, x


class TestWorkflowNodeStructure(unittest.TestCase):
"""Tests for protocol validation and schema availability."""

Expand Down Expand Up @@ -929,5 +933,40 @@ def test_not_subset_rejected(self):
self.assertIn("z", str(ctx.exception))


class TestWorkflowNodeCall(unittest.TestCase):
def test_call_without_reference_raises(self):
recipe = workflow_model.WorkflowNode(
inputs=[],
outputs=[],
nodes={},
input_edges={},
edges={},
output_edges={},
)
with self.assertRaises(
ValueError,
msg="Calling a workflow recipe without a reference should alert us to "
"the reference's absence",
) as ctx:
recipe()
self.assertIn("only callable when", str(ctx.exception))
self.assertIn("reference field", str(ctx.exception))

def test_call_with_reference(self):
recipe = workflow_model.WorkflowNode(
inputs=["x", "y"],
outputs=["y", "x"],
nodes={},
input_edges={},
edges={},
output_edges={"y": "x", "x": "y"},
reference=makers.make_reference(_some_wf.__module__, _some_wf.__name__),
)
xi, yi = 1, 2
yo, xo = recipe(xi, yi)
self.assertEqual(xo, xi)
self.assertEqual(yo, yi)


if __name__ == "__main__":
unittest.main()
Loading