Skip to content

Commit

Permalink
update dask vine executor to new dask graphs (#4015)
Browse files Browse the repository at this point in the history
* DaskVine to new graph representation

* update simple graph example

* fix bug with depth

* always convert from legacy representation, for now

* check for dask in test

* do not import DaskVineDag if dask not available

* update function calls

* lint

* handle generic container graph nodes

* add warning about dask version

* example_to_revert

* remove print statement
  • Loading branch information
btovar authored Jan 17, 2025
1 parent 5ae1a95 commit c77ab6f
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 222 deletions.
3 changes: 1 addition & 2 deletions taskvine/src/bindings/python3/ndcctools/taskvine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@
LibraryTask,
FunctionCall,
)
from .dask_dag import DaskVineDag

from . import cvine

try:
from .dask_executor import DaskVine
from .dask_dag import DaskVineDag
except ImportError as e:
print(f"DaskVine not available. Couldn't find module: {e.name}")

Expand Down
216 changes: 93 additions & 123 deletions taskvine/src/bindings/python3/ndcctools/taskvine/dask_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,213 +2,183 @@
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

from uuid import uuid4
from collections import defaultdict
import dask._task_spec as dts


class DaskVineDag:
"""A directed graph that encodes the steps and state a computation needs.
Single computations are encoded as s-expressions, therefore it is 'upside-down',
in the sense that the children of a node are the nodes required to compute it.
E.g., for
Single computations are encoded as dts.Task's, with dependecies expressed as the keys needed by the task.
dsk = {'x': 1,
'y': 2,
'z': (add, 'x', 'y'),
'w': (sum, ['x', 'y', 'z']),
'v': [(sum, ['w', 'z']), 2]
'z': dts.Task('z', add, dts.TaskRef('x'), dts.TaskRef('y'))
'w': dts.Task('w', sum, [dts.TaskRef('x'), dts.TaskRef('y'), dts.TaskRef('z')]),
'v': dts.Task('v', sum, [dts.TaskRef('w'), dts.TaskRef('z')])
't': dts.Task('v', sum, [dts.TaskRef('v'), 2])
}
'z' has as children 'x' and 'y'.
Each node is referenced by its key. When the value of a key is list of
sexprs, like 'v' above, and low_memory_mode is True, then a key is automatically computed recursively
for each computation.
'z' has as dependecies 'x' and 'y'.
Computation is done lazily. The DaskVineDag is initialized from a task graph, but not
computation is decoded. To use the DaskVineDag:
- DaskVineDag.set_targets(keys): Request the computation associated with key to be decoded.
- DaskVineDag.get_ready(): A list of [key, sexpr] of expressions that are ready
to be executed.
- DaskVineDag.get_ready(): A list of dts.Task that are ready to be executed.
- DaskVineDag.set_result(key, value): Sets the result of key to value.
- DaskVineDag.get_result(key): Get result associated with key. Raises DagNoResult
- DaskVineDag.has_result(key): Whether the key has a computed result. """

@staticmethod
def hashable(s):
try:
hash(s)
return True
except TypeError:
return False

@staticmethod
def keyp(s):
return DaskVineDag.hashable(s) and not DaskVineDag.taskp(s)
return DaskVineDag.hashable(s) and not DaskVineDag.taskref(s) and not DaskVineDag.taskp(s)

@staticmethod
def taskp(s):
return isinstance(s, tuple) and len(s) > 0 and callable(s[0])
def taskref(s):
return isinstance(s, (dts.TaskRef, dts.Alias))

@staticmethod
def listp(s):
return isinstance(s, list)
def taskp(s):
return isinstance(s, dts.Task)

@staticmethod
def symbolp(s):
return not (DaskVineDag.taskp(s) or DaskVineDag.listp(s))
def containerp(s):
return isinstance(s, dts.NestedContainer)

@staticmethod
def hashable(s):
try:
hash(s)
return True
except TypeError:
return False
def symbolp(s):
return isinstance(s, dts.DataNode)

def __init__(self, dsk, low_memory_mode=False):
def __init__(self, dsk):
self._dsk = dsk

# child -> parents. I.e., which parents needs the result of child
self._parents_of = defaultdict(lambda: set())
# For a key, the set of keys that need it to perform a computation.
self._needed_by = defaultdict(lambda: set())

# parent->children still waiting for result. A key is ready to be computed when children left is []
self._missing_of = {}
# For a key, the subset of self._needed_by[key] that still need to be completed.
# Only useful for gc.
self._pending_needed_by = defaultdict(lambda: set())

# parent->nchildren get the number of children for parent computation
self._children_of = {}
# For a key, the set of keys that it needs for computation.
self._dependencies_of = {}

# For a key, the set of keys with a pending result for they key to be computed.
# When the set is empty, the key is ready to be computed. It is always a subset
# of self._dependencies_of[key].
self._missing_of = {}

# key->value of its computation
self._result_of = {}

# child -> nodes that use the child as an input, and that have not been completed
self._pending_parents_of = defaultdict(lambda: set())

# key->depth. The shallowest level the key is found
self._depth_of = defaultdict(lambda: float('inf'))

# target keys that the dag should compute
self._targets = set()

self._working_graph = dict(dsk)
if low_memory_mode:
self._flatten_graph()

self.initialize_graph()

def left_to_compute(self):
return len(self._working_graph) - len(self._result_of)

def graph_keyp(self, s):
if DaskVineDag.keyp(s):
return s in self._working_graph
return False

def depth_of(self, key):
return self._depth_of[key]

def initialize_graph(self):
for key, sexpr in self._working_graph.items():
self.set_relations(key, sexpr)

def find_dependencies(self, sexpr, depth=0):
dependencies = set()
if self.graph_keyp(sexpr):
dependencies.add(sexpr)
self._depth_of[sexpr] = min(depth, self._depth_of[sexpr])
elif not DaskVineDag.symbolp(sexpr):
for sub in sexpr:
dependencies.update(self.find_dependencies(sub, depth + 1))
return dependencies
for task in self._working_graph.values():
self.set_relations(task)

def set_relations(self, key, sexpr):
sexpr = self._working_graph[key]
for task in self._working_graph.values():
if isinstance(task, dts.DataNode):
self._depth_of[task.key] = 0
self.set_result(task.key, task.value)

self._children_of[key] = self.find_dependencies(sexpr)
self._depth_of[key] = max([self._depth_of[c] for c in self._children_of[key]]) + 1 if self._children_of[key] else 0

self._missing_of[key] = set(self._children_of[key])

for c in self._children_of[key]:
self._parents_of[c].add(key)
self._pending_parents_of[c].add(key)
def set_relations(self, task):
self._dependencies_of[task.key] = task.dependencies
self._missing_of[task.key] = set(self._dependencies_of[task.key])
for c in self._dependencies_of[task.key]:
self._needed_by[c].add(task.key)
self._pending_needed_by[c].add(task.key)

def get_ready(self):
""" List of [(key, sexpr),...] ready for computation.
""" List of dts.Task ready for computation.
This call should be used only for
bootstrapping. Further calls should use DaskVineDag.set_result to discover
the new computations that become ready to be executed. """
rs = {}
for (key, cs) in self._missing_of.items():
if self.has_result(key) or cs:
continue
sexpr = self._working_graph[key]
if self.graph_keyp(sexpr):
rs.update(self.set_result(key, self.get_result(sexpr)))
elif self.symbolp(sexpr):
rs.update(self.set_result(key, sexpr))
node = self._working_graph[key]
if self.taskref(node):
rs.update(self.set_result(key, self.get_result(node.key)))
elif self.symbolp(node):
rs.update(self.set_result(key, node))
else:
rs[key] = (key, sexpr)
rs[key] = node

for r in rs:
if self._dependencies_of[r]:
self._depth_of[r] = min(self._depth_of[d] for d in self._dependencies_of[r]) + 1
else:
self._depth_of[r] = 0

return rs.values()

def set_result(self, key, value):
""" Sets new result and propagates in the DaskVineDag. Returns a list of [(key, sexpr),...]
""" Sets new result and propagates in the DaskVineDag. Returns a list of dts.Task
of computations that become ready to be executed """
rs = {}
self._result_of[key] = value
for p in self._parents_of[key]:
for p in self._pending_needed_by[key]:
self._missing_of[p].discard(key)

if self._missing_of[p]:
# the key p still has dependencies unmet...
continue

sexpr = self._working_graph[p]
if self.graph_keyp(sexpr):
node = self._working_graph[p]
if self.taskref(node):
rs.update(
self.set_result(p, self.get_result(sexpr))
self.set_result(p, self.get_result(node))
) # case e.g, "x": "y", and we just set the value of "y"
elif self.symbolp(sexpr):
rs.update(self.set_result(p, sexpr))
elif self.symbolp(node):
rs.update(self.set_result(p, node))
else:
rs[p] = (p, sexpr)
rs[p] = node

for c in self._children_of[key]:
self._pending_parents_of[c].discard(key)
for r in rs:
if self._dependencies_of[r]:
self._depth_of[r] = min(self._depth_of[d] for d in self._dependencies_of[r]) + 1
else:
self._depth_of[r] = 0

return rs.values()
for c in self._dependencies_of[key]:
self._pending_needed_by[c].discard(key)

def _flatten_graph(self):
""" Recursively decomposes a sexpr associated with key, so that its arguments, if any
are keys. """
for key in list(self._working_graph.keys()):
self.flatten_rec(key, self._working_graph[key], toplevel=True)
return rs.values()

def _add_second_targets(self, key):
v = self._working_graph[key]
if self.graph_keyp(v):
if self.taskref(v):
lst = [v]
elif DaskVineDag.listp(v):
elif DaskVineDag.containerp(v):
lst = v
else:
return
for c in lst:
if self.graph_keyp(c):
self._targets.add(c)
self._add_second_targets(c)

def flatten_rec(self, key, sexpr, toplevel=False):
if key in self._working_graph and not toplevel:
return
if DaskVineDag.symbolp(sexpr):
return

nargs = []
next_flat = []
cons = type(sexpr)

for arg in sexpr:
if DaskVineDag.symbolp(arg):
nargs.append(arg)
else:
next_key = uuid4()
nargs.append(next_key)
next_flat.append((next_key, arg))

self._working_graph[key] = cons(nargs)
for (n, a) in next_flat:
self.flatten_rec(n, a)
if self.taskref(c):
self._targets.add(c.key)
self._add_second_targets(c.key)

def has_result(self, key):
return key in self._result_of
Expand All @@ -219,17 +189,17 @@ def get_result(self, key):
except KeyError:
raise DaskVineNoResult(key)

def get_children(self, key):
return self._children_of[key]
def get_dependencies(self, key):
return self._dependencies_of[key]

def get_missing_children(self, key):
def get_missing_dependencies(self, key):
return self._missing_of[key]

def get_parents(self, key):
return self._parents_of[key]
def get_needed_by(self, key):
return self._needed_by[key]

def get_pending_parents(self, key):
return self._pending_parents_of[key]
def get_pending_needed_by(self, key):
return self._pending_needed_by[key]

def set_targets(self, keys):
""" Values of keys that need to be computed. """
Expand Down
Loading

0 comments on commit c77ab6f

Please sign in to comment.