graphkit/functional.py (13 changes: 8 additions & 5 deletions)

@@ -3,6 +3,8 @@
 
 from itertools import chain
 
+from boltons.setutils import IndexedSet as iset
+
 from .base import Operation, NetworkOperation
 from .network import Network
 from .modifiers import optional
@@ -28,7 +30,7 @@ def _compute(self, named_inputs, outputs=None):
 
         result = zip(self.provides, result)
         if outputs:
-            outputs = set(outputs)
+            outputs = sorted(set(outputs))
             result = filter(lambda x: x[0] in outputs, result)
 
         return dict(result)
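
The new sorted(set(outputs)) makes the requested-outputs filter work from a stable, repeatable sequence instead of an unordered set, in line with the determinism theme of this change set. A tiny standalone illustration (the output names below are made up, not graphkit's):

    requested = {"ab", "a_minus_ab", "ab"}   # hypothetical output names
    outputs = sorted(set(requested))
    print(outputs)            # ['a_minus_ab', 'ab'] -- same order on every run
    print("ab" in outputs)    # True, membership tests behave as before
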
@@ -185,22 +187,23 @@ def __call__(self, *operations):
 
         # If merge is desired, deduplicate operations before building network
         if self.merge:
-            merge_set = set()
+            merge_set = iset()  # Preserve given node order.
             for op in operations:
                 if isinstance(op, NetworkOperation):
                     net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps)
                     merge_set.update(net_ops)
                 else:
                     merge_set.add(op)
-            operations = list(merge_set)
+            operations = merge_set
 
         def order_preserving_uniquifier(seq, seen=None):
-            seen = seen if seen else set()
+            seen = seen if seen else set()  # unordered, not iterated
             seen_add = seen.add
             return [x for x in seq if not (x in seen or seen_add(x))]
 
         provides = order_preserving_uniquifier(chain(*[op.provides for op in operations]))
-        needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides))
+        needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]),
+                                            set(provides))  # unordered, not iterated
 
         # compile network
         net = Network()
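
boltons' IndexedSet (imported above as iset) de-duplicates like a set but keeps first-seen insertion order, so merged operations stay in the order the caller composed them rather than whatever order a plain set happens to hash them into. A minimal sketch with hypothetical operation names:

    from boltons.setutils import IndexedSet as iset

    merge_set = iset()
    for op in ["sum_op", "mul_op", "sum_op", "sub_op"]:   # hypothetical names
        merge_set.add(op)

    print(list(merge_set))                       # ['sum_op', 'mul_op', 'sub_op'] -- stable
    print(list({"sum_op", "mul_op", "sub_op"}))  # plain set: order is arbitrary
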
graphkit/network.py (15 changes: 9 additions & 6 deletions)

@@ -7,6 +7,8 @@
 
 from io import StringIO
 
+from boltons.setutils import IndexedSet as iset
+
 from .base import Operation
 
 
@@ -107,7 +109,7 @@ def compile(self):
         self.steps = []
 
         # create an execution order such that each layer's needs are provided.
-        ordered_nodes = list(nx.dag.topological_sort(self.graph))
+        ordered_nodes = iset(nx.topological_sort(self.graph))
 
         # add Operations evaluation steps, and instructions to free data.
         for i, node in enumerate(ordered_nodes):
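
In networkx 2.x, topological_sort() is a top-level function that returns a generator (the nx.dag alias is gone), so the result has to be materialized anyway; collecting it into an IndexedSet keeps the topological order while still supporting set-style membership tests. A small sketch over a toy graph (not graphkit's own):

    import networkx as nx
    from boltons.setutils import IndexedSet as iset

    g = nx.DiGraph()
    g.add_edges_from([("a", "op1"), ("op1", "b"), ("b", "op2"), ("op2", "c")])

    ordered_nodes = iset(nx.topological_sort(g))
    print(list(ordered_nodes))     # ['a', 'op1', 'b', 'op2', 'c']
    print("op1" in ordered_nodes)  # True -- still usable as a set
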
@@ -163,7 +165,7 @@ def _find_necessary_steps(self, outputs, inputs):
         """
 
         # return steps if it has already been computed before for this set of inputs and outputs
-        outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
+        outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set, iset)) else outputs
         inputs_keys = tuple(sorted(inputs.keys()))
         cache_key = (inputs_keys, outputs)
         if cache_key in self._necessary_steps_cache:
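
Extending the isinstance check to iset keeps the cache key working now that callers may pass an IndexedSet of outputs: any ordered or unordered collection collapses to the same sorted, hashable tuple. A small sketch of the idea (simplified helper, not graphkit's exact code):

    def make_cache_key(outputs, inputs):   # hypothetical helper
        outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
        return (tuple(sorted(inputs.keys())), outputs)

    key1 = make_cache_key({"c", "b"}, {"a": 1, "x": 2})
    key2 = make_cache_key(["b", "c"], {"x": 2, "a": 1})
    print(key1 == key2)   # True -- both are (('a', 'x'), ('b', 'c'))
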
@@ -175,7 +177,7 @@ def _find_necessary_steps(self, outputs, inputs):
             # If caller requested all outputs, the necessary nodes are all
             # nodes that are reachable from one of the inputs. Ignore input
             # names that aren't in the graph.
-            necessary_nodes = set()
+            necessary_nodes = set()  # unordered, not iterated
             for input_name in iter(inputs):
                 if graph.has_node(input_name):
                     necessary_nodes |= nx.descendants(graph, input_name)
@@ -186,15 +188,15 @@ def _find_necessary_steps(self, outputs, inputs):
             # are made unnecessary because we were provided with an input that's
             # deeper into the network graph. Ignore input names that aren't
             # in the graph.
-            unnecessary_nodes = set()
+            unnecessary_nodes = set()  # unordered, not iterated
             for input_name in iter(inputs):
                 if graph.has_node(input_name):
                     unnecessary_nodes |= nx.ancestors(graph, input_name)
 
             # Find the nodes we need to be able to compute the requested
             # outputs. Raise an exception if a requested output doesn't
             # exist in the graph.
-            necessary_nodes = set()
+            necessary_nodes = set()  # unordered, not iterated
             for output_name in outputs:
                 if not graph.has_node(output_name):
                     raise ValueError("graphkit graph does not have an output "
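
These sets are only unioned and intersected, never iterated, which is why plain set() remains safe here (as the new comments note). The pruning itself leans on two networkx calls; a toy example (assumed graph, not graphkit's) of what they return:

    import networkx as nx

    graph = nx.DiGraph()
    graph.add_edges_from([("a", "op1"), ("op1", "b"), ("b", "op2"), ("op2", "c")])

    print(nx.descendants(graph, "b"))  # {'op2', 'c'} -- reachable from 'b' (order may vary)
    print(nx.ancestors(graph, "b"))    # {'a', 'op1'} -- upstream of 'b', unneeded once 'b' is given
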
@@ -266,7 +268,7 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
         necessary_nodes = self._find_necessary_steps(outputs, named_inputs)
 
         # this keeps track of all nodes that have already executed
-        has_executed = set()
+        has_executed = set()  # unordered, not iterated
 
         # with each loop iteration, we determine a set of operations that can be
         # scheduled, then schedule them onto a thread pool, then collect their
@@ -464,6 +466,7 @@ def ready_to_schedule_operation(op, has_executed, graph):
         A boolean indicating whether the operation may be scheduled for
         execution based on what has already been executed.
     """
+    # unordered, not iterated
     dependencies = set(filter(lambda v: isinstance(v, Operation),
                               nx.ancestors(graph, op)))
     return dependencies.issubset(has_executed)

setup.py (6 changes: 5 additions & 1 deletion)

@@ -28,7 +28,11 @@
     author_email='huyng@yahoo-inc.com',
     url='http://github.com/yahoo/graphkit',
     packages=['graphkit'],
-    install_requires=['networkx'],
+    install_requires=[
+        "networkx; python_version >= '3.5'",
+        "networkx == 2.2; python_version < '3.5'",
+        "boltons"  # for IndexedSet
+    ],
     extras_require={
         'plot': ['pydot', 'matplotlib']
     },
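
The two networkx entries use PEP 508 environment markers, so pip installs a current networkx on Python 3.5+ and pins networkx 2.2 on older interpreters (later networkx releases dropped Python 2 support). A hedged sketch of how such markers evaluate, assuming the third-party packaging distribution is installed (pip uses it internally):

    from packaging.markers import Marker

    print(Marker("python_version >= '3.5'").evaluate())   # against the current interpreter
    print(Marker("python_version < '3.5'").evaluate({"python_version": "2.7"}))  # True
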