Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions graphkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class Operation(object):
specific application.
"""

#: Owning :class:`~.network.Network`, set when added in a network.
#: Needed by `_compute()` to detect *optional needs* from edge-attributes.
net = None

def __init__(self, **kwargs):
"""
Create a new layer instance.
Expand Down
15 changes: 13 additions & 2 deletions graphkit/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,22 @@ def __init__(self, **kwargs):
Operation.__init__(self, **kwargs)

def _compute(self, named_inputs, outputs=None):
inputs = [named_inputs[d] for d in self.needs if not isinstance(d, optional)]
assert self.net

inputs = [
named_inputs[n]
for n in self.needs
if 'optional' not in self.net.graph.get_edge_data(n, self)
]

# Find any optional inputs in named_inputs. Get only the ones that
# are present there, no extra `None`s.
optionals = {n: named_inputs[n] for n in self.needs if isinstance(n, optional) and n in named_inputs}
optionals = {
n: named_inputs[n]
for n in self.needs
if 'optional' in self.net.graph.get_edge_data(n, self)
and n in named_inputs
}

# Combine params and optionals into one big glob of keyword arguments.
kwargs = {k: v for d in (self.params, optionals) for k, v in d.items()}
Expand Down
82 changes: 78 additions & 4 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from io import StringIO

from .base import Operation
from . import modifiers


class DataPlaceholderNode(str):
Expand Down Expand Up @@ -73,9 +74,16 @@ def add_op(self, operation):
# assert layer is only added once to graph
assert operation not in self.graph.nodes(), "Operation may only be added once"

# functionalOperations don't have that set.
if not operation.net:
operation.net = self

# add nodes and edges to graph describing the data needs for this layer
for n in operation.needs:
self.graph.add_edge(DataPlaceholderNode(n), operation)
if isinstance(n, modifiers.optional):
self.graph.add_edge(DataPlaceholderNode(n), operation, optional=True)
else:
self.graph.add_edge(DataPlaceholderNode(n), operation)

# add nodes and edges to graph describing what this layer provides
for p in operation.provides:
Expand Down Expand Up @@ -107,7 +115,7 @@ def compile(self):
self.steps = []

# create an execution order such that each layer's needs are provided.
ordered_nodes = list(nx.dag.topological_sort(self.graph))
ordered_nodes = list(nx.topological_sort(self.graph))

# add Operations evaluation steps, and instructions to free data.
for i, node in enumerate(ordered_nodes):
Expand Down Expand Up @@ -141,6 +149,65 @@ def compile(self):
raise TypeError("Unrecognized network graph node")


def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited):
    """
    Recursively decide whether `operation` can run with the given `inputs`.

    An operation is satisfiable when each of its non-optional needs is
    either contained in `inputs` or has at least one satisfiable provider
    operation among its graph predecessors.

    :param operation:
        the operation to examine
    :param inputs:
        the collection of names of the values given
    :param satisfiables:
        the set to populate with satisfiable operations
    :param visited:
        a cache of verdicts for operations & needs, not to visit them again
    :return:
        true if operation is satisfiable
    """
    assert isinstance(operation, Operation), (
        "Expected Operation, got:",
        type(operation),
    )

    # An earlier visit has already settled this operation.
    if operation in visited:
        return visited[operation]

    def need_is_met(name):
        # Re-use the verdict of a previous visit to this need, if any.
        if name in visited:
            return visited[name]

        verdict = name in inputs
        if not verdict:
            # Not given directly: it takes one satisfiable provider.
            for provider in list(self.graph.predecessors(name)):
                if self._collect_satisfiable_needs(
                    provider, inputs, satisfiables, visited
                ):
                    verdict = True
                    break
        visited[name] = verdict

        return verdict

    runnable = True
    for name in operation.needs:
        # Needs attached through an "optional" edge never block the operation.
        if 'optional' in self.graph.get_edge_data(name, operation):
            continue
        if not need_is_met(name):
            runnable = False
            break

    if runnable:
        satisfiables.add(operation)
    visited[operation] = runnable

    return runnable


def _collect_satisfiable_operations(self, nodes, inputs):
    """
    Collect the operations among `nodes` that are satisfiable by `inputs`.

    :param nodes:
        the graph nodes to scan; anything that is not an `Operation` is skipped
    :param inputs:
        the collection of names of the values given
    :return:
        the set of satisfiable operations
    """
    found = set()
    verdicts = {}  # shared cache, filled in by the recursive visits

    for candidate in nodes:
        # Skip nodes already settled while visiting a previous one.
        if candidate in verdicts:
            continue
        if isinstance(candidate, Operation):
            self._collect_satisfiable_needs(candidate, inputs, found, verdicts)

    return found


def _find_necessary_steps(self, outputs, inputs):
"""
Determines what graph steps need to be run to get to the requested
Expand Down Expand Up @@ -204,6 +271,13 @@ def _find_necessary_steps(self, outputs, inputs):
# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (unsatisfiable) operations with partial inputs.
# See https://github.com/yahoo/graphkit/pull/18
#
satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs)
for node in list(necessary_nodes):
if isinstance(node, Operation) and node not in satisfiables:
necessary_nodes.remove(node)

necessary_steps = [step for step in self.steps if step in necessary_nodes]

Expand Down Expand Up @@ -422,8 +496,8 @@ def get_node_name(a):

# save plot
if filename:
basename, ext = os.path.splitext(filename)
with open(filename, "w") as fh:
_basename, ext = os.path.splitext(filename)
with open(filename, "wb") as fh:
if ext.lower() == ".png":
fh.write(g.create_png())
elif ext.lower() == ".dot":
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
author_email='huyng@yahoo-inc.com',
url='http://github.com/yahoo/graphkit',
packages=['graphkit'],
install_requires=['networkx'],
install_requires=[
"networkx; python_version >= '3.5'",
"networkx == 2.2; python_version < '3.5'",
],
extras_require={
'plot': ['pydot', 'matplotlib']
},
Expand Down
46 changes: 44 additions & 2 deletions test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle

from pprint import pprint
from operator import add
from operator import add, mul, floordiv
from numpy.testing import assert_raises

import graphkit.network as network
Expand All @@ -32,7 +32,9 @@ def mul_op1(a, b):
@operation(name='pow_op1', needs='sum_ab', provides=['sum_ab_p1', 'sum_ab_p2', 'sum_ab_p3'], params={'exponent': 3})
def pow_op1(a, exponent=2):
return [math.pow(a, y) for y in range(1, exponent+1)]


# `_compute()` needs an nx-DiGraph in the op's `net` attribute.
compose("mock graph")(pow_op1)
print(pow_op1._compute({'sum_ab':2}, ['sum_ab_p2']))

# Partial operation that is bound at a later time
Expand Down Expand Up @@ -69,6 +71,22 @@ def pow_op1(a, exponent=2):
# net.plot(show=True)


def test_operations_with_partial_inputs_ignored():
    # "mul" and "div" both provide `ab`, but each needs a different second
    # input (`b1` vs `b2`); the expected values below correspond to only the
    # operation whose needs are all given having run.
    pipeline = compose(name="graph")(
        operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
        operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
        operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
    )

    # (name of the second operand given, expected `ab`, expected `ab_plus_c`)
    scenarios = [
        ("b1", 20, 21),
        ("b2", 5, 6),
    ]
    for operand, ab, total in scenarios:
        # Without `outputs`, the given inputs are part of the results.
        everything = {"a": 10, operand: 2, "c": 1, "ab": ab, "ab_plus_c": total}
        assert pipeline({"a": 10, operand: 2, "c": 1}) == everything

        # Asking for a single output returns just that.
        only_total = pipeline({"a": 10, operand: 2, "c": 1}, outputs=["ab_plus_c"])
        assert only_total == {"ab_plus_c": total}


def test_network_simple_merge():

sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add)
Expand Down Expand Up @@ -208,6 +226,30 @@ def addplusplus(a, b, c=0):
assert results['sum'] == sum(named_inputs.values())


def test_optional_per_function():
    # Test that the same need can be both optional and not on different operations.
    net = compose(name='partial_optionals')(
        operation(name='sum', needs=['a', 'b'], provides='a+b')(add),
        operation(name='sub_opt', needs=['a', modifiers.optional('b')], provides='a+b')
        (lambda a, b=10: a - b),
    )

    # Both inputs given: the expected `a+b` is the plain sum of the inputs,
    # with and without asking for specific outputs.
    named_inputs = {'a': 1, 'b': 2}
    results = net(named_inputs)
    assert 'a+b' in results
    assert results['a+b'] == sum(named_inputs.values())
    results = net(named_inputs, ['a+b'])
    assert 'a+b' in results
    assert results['a+b'] == sum(named_inputs.values())

    # `b` missing: the expected value is `a` minus the default `b=10` of the
    # 'sub_opt' function (1 - 10 == -9), so 'sum' must not have produced it.
    named_inputs = {'a': 1}
    results = net(named_inputs)
    assert 'a+b' in results
    assert results['a+b'] == -9
    # Re-run asking explicitly for the output, mirroring the first scenario;
    # previously the same `results` were merely asserted a second time.
    results = net(named_inputs, ['a+b'])
    assert 'a+b' in results
    assert results['a+b'] == -9


def test_deleted_optional():
# Test that DeleteInstructions included for optionals do not raise
# exceptions when the corresponding input is not provided.
Expand Down