diff --git a/graphkit/base.py b/graphkit/base.py index 1c04e8d5..d5d27353 100644 --- a/graphkit/base.py +++ b/graphkit/base.py @@ -26,6 +26,10 @@ class Operation(object): specific application. """ + #: Owning :class:`~.network.Network`, set when added in a network. + #: Needed by `_compute()` to detect *optional needs* from edge-attributes. + net = None + def __init__(self, **kwargs): """ Create a new layer instance. diff --git a/graphkit/functional.py b/graphkit/functional.py index 65388973..baa79474 100644 --- a/graphkit/functional.py +++ b/graphkit/functional.py @@ -14,11 +14,22 @@ def __init__(self, **kwargs): Operation.__init__(self, **kwargs) def _compute(self, named_inputs, outputs=None): - inputs = [named_inputs[d] for d in self.needs if not isinstance(d, optional)] + assert self.net + + inputs = [ + named_inputs[n] + for n in self.needs + if 'optional' not in self.net.graph.get_edge_data(n, self) + ] # Find any optional inputs in named_inputs. Get only the ones that # are present there, no extra `None`s. - optionals = {n: named_inputs[n] for n in self.needs if isinstance(n, optional) and n in named_inputs} + optionals = { + n: named_inputs[n] + for n in self.needs + if 'optional' in self.net.graph.get_edge_data(n, self) + and n in named_inputs + } # Combine params and optionals into one big glob of keyword arguments. kwargs = {k: v for d in (self.params, optionals) for k, v in d.items()} diff --git a/graphkit/network.py b/graphkit/network.py index 0df3ddf8..980a01e8 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -8,6 +8,7 @@ from io import StringIO from .base import Operation +from . import modifiers class DataPlaceholderNode(str): @@ -73,9 +74,16 @@ def add_op(self, operation): # assert layer is only added once to graph assert operation not in self.graph.nodes(), "Operation may only be added once" + # functionalOperations don't have that set. 
+ if not operation.net: + operation.net = self + # add nodes and edges to graph describing the data needs for this layer for n in operation.needs: - self.graph.add_edge(DataPlaceholderNode(n), operation) + if isinstance(n, modifiers.optional): + self.graph.add_edge(DataPlaceholderNode(n), operation, optional=True) + else: + self.graph.add_edge(DataPlaceholderNode(n), operation) # add nodes and edges to graph describing what this layer provides for p in operation.provides: @@ -107,7 +115,7 @@ def compile(self): self.steps = [] # create an execution order such that each layer's needs are provided. - ordered_nodes = list(nx.dag.topological_sort(self.graph)) + ordered_nodes = list(nx.topological_sort(self.graph)) # add Operations evaluation steps, and instructions to free data. for i, node in enumerate(ordered_nodes): @@ -141,6 +149,65 @@ def compile(self): raise TypeError("Unrecognized network graph node") + def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited): + """ + Recursively check if operation inputs are given/calculated (satisfied), or not. 
+ + :param satisfiables: + the set to populate with satisfiable operations + + :param visited: + a cache of operations & needs, not to visit them again + :return: + true if operation is satisfiable + """ + assert isinstance(operation, Operation), ( + "Expected Operation, got:", + type(operation), + ) + + if operation in visited: + return visited[operation] + + + def is_need_satisfiable(need): + if need in visited: + return visited[need] + + if need in inputs: + satisfied = True + else: + need_providers = list(self.graph.predecessors(need)) + satisfied = bool(need_providers) and any( + self._collect_satisfiable_needs(op, inputs, satisfiables, visited) + for op in need_providers + ) + visited[need] = satisfied + + return satisfied + + satisfied = all( + is_need_satisfiable(need) + for need in operation.needs + if 'optional' not in self.graph.get_edge_data(need, operation) + ) + if satisfied: + satisfiables.add(operation) + visited[operation] = satisfied + + return satisfied + + + def _collect_satisfiable_operations(self, nodes, inputs): + satisfiables = set() + visited = {} + for node in nodes: + if node not in visited and isinstance(node, Operation): + self._collect_satisfiable_needs(node, inputs, satisfiables, visited) + + return satisfiables + + def _find_necessary_steps(self, outputs, inputs): """ Determines what graph steps need to pe run to get to the requested @@ -204,6 +271,13 @@ def _find_necessary_steps(self, outputs, inputs): # Get rid of the unnecessary nodes from the set of necessary ones. necessary_nodes -= unnecessary_nodes + # Drop (un-satisfiable) operations with partial inputs. 
+ # See https://github.com/yahoo/graphkit/pull/18 + # + satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs) + for node in list(necessary_nodes): + if isinstance(node, Operation) and node not in satisfiables: + necessary_nodes.remove(node) necessary_steps = [step for step in self.steps if step in necessary_nodes] @@ -422,8 +496,8 @@ def get_node_name(a): # save plot if filename: - basename, ext = os.path.splitext(filename) - with open(filename, "w") as fh: + _basename, ext = os.path.splitext(filename) + with open(filename, "wb") as fh: if ext.lower() == ".png": fh.write(g.create_png()) elif ext.lower() == ".dot": diff --git a/setup.py b/setup.py index bd7883f4..d3dfec84 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,10 @@ author_email='huyng@yahoo-inc.com', url='http://github.com/yahoo/graphkit', packages=['graphkit'], - install_requires=['networkx'], + install_requires=[ + "networkx; python_version >= '3.5'", + "networkx == 2.2; python_version < '3.5'", + ], extras_require={ 'plot': ['pydot', 'matplotlib'] }, diff --git a/test/test_graphkit.py b/test/test_graphkit.py index bd97b317..2bf7d915 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -5,7 +5,7 @@ import pickle from pprint import pprint -from operator import add +from operator import add, mul, floordiv from numpy.testing import assert_raises import graphkit.network as network @@ -32,7 +32,9 @@ def mul_op1(a, b): @operation(name='pow_op1', needs='sum_ab', provides=['sum_ab_p1', 'sum_ab_p2', 'sum_ab_p3'], params={'exponent': 3}) def pow_op1(a, exponent=2): return [math.pow(a, y) for y in range(1, exponent+1)] - + + # `_compute()` needs an nx-DiGraph in op's `net` attribute. 
+ compose("mock graph")(pow_op1) print(pow_op1._compute({'sum_ab':2}, ['sum_ab_p2'])) # Partial operation that is bound at a later time @@ -69,6 +71,22 @@ def pow_op1(a, exponent=2): # net.plot(show=True) +def test_operations_with_partial_inputs_ignored(): + graph = compose(name="graph")( + operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul), + operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv), + operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add), + ) + + exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21} + assert graph({"a": 10, "b1": 2, "c": 1}) == exp + assert graph({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 21} + + exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6} + assert graph({"a": 10, "b2": 2, "c": 1}) == exp + assert graph({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 6} + + def test_network_simple_merge(): sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add) @@ -208,6 +226,30 @@ def addplusplus(a, b, c=0): assert results['sum'] == sum(named_inputs.values()) +def test_optional_per_function(): + # Test that the same need can be both optional and not on different operations. 
+ net = compose(name='partial_optionals')( + operation(name='sum', needs=['a', 'b'], provides='a+b')(add), + operation(name='sub_opt', needs=['a', modifiers.optional('b')], provides='a+b') + (lambda a, b=10: a - b), + ) + + named_inputs = {'a': 1, 'b': 2} + results = net(named_inputs) + assert 'a+b' in results + assert results['a+b'] == sum(named_inputs.values()) + results = net(named_inputs, ['a+b']) + assert 'a+b' in results + assert results['a+b'] == sum(named_inputs.values()) + + named_inputs = {'a': 1} + results = net(named_inputs) + assert 'a+b' in results + assert results['a+b'] == -9 + assert 'a+b' in results + assert results['a+b'] == -9 + + def test_deleted_optional(): # Test that DeleteInstructions included for optionals do not raise # exceptions when the corresponding input is not prodided.