diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 5b4a1e21..1a580ab6 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -5,7 +5,7 @@
 import weakref
 from collections import defaultdict
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, NamedTuple
 
 import dask
 import pandas as pd
@@ -29,6 +29,10 @@
 ]
 
 
+class BranchId(NamedTuple):
+    branch_id: int
+
+
 def _unpack_collections(o):
     if isinstance(o, Expr):
         return o
@@ -44,8 +48,13 @@ class Expr:
     _defaults = {}
     _instances = weakref.WeakValueDictionary()
 
-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args, _branch_id=None, **kwargs):
         operands = list(args)
+        if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
+            _branch_id = operands.pop(-1)
+        elif _branch_id is None:
+            _branch_id = BranchId(0)
+
         for parameter in cls._parameters[len(operands) :]:
             try:
                 operands.append(kwargs.pop(parameter))
@@ -54,6 +63,7 @@ def __new__(cls, *args, **kwargs):
         assert not kwargs, kwargs
         inst = object.__new__(cls)
         inst.operands = [_unpack_collections(o) for o in operands]
+        inst._branch_id = _branch_id
         _name = inst._name
         if _name in Expr._instances:
             return Expr._instances[_name]
@@ -116,7 +126,10 @@ def _tree_repr_lines(self, indent=0, recursive=True):
             elif is_arraylike(op):
                 op = "<array>"
             header = self._tree_repr_argument_construction(i, op, header)
-
+        if self._branch_id.branch_id != 0:
+            header = self._tree_repr_argument_construction(
+                i + 1, f" branch_id={self._branch_id.branch_id}", header
+            )
         lines = [header] + lines
         lines = ["  " * indent + line for line in lines]
 
@@ -218,7 +231,7 @@ def _layer(self) -> dict:
 
         return {(self._name, i): self._task(i) for i in range(self.npartitions)}
 
-    def rewrite(self, kind: str):
+    def rewrite(self, kind: str, cache):
         """Rewrite an expression
 
         This leverages the ``._{kind}_down`` and ``._{kind}_up``
@@ -231,6 +244,9 @@ def rewrite(self, kind: str):
         changed:
             whether or not any change occurred
 
         """
+        if self._name in cache:
+            return cache[self._name]
+
         expr = self
         down_name = f"_{kind}_down"
         up_name = f"_{kind}_up"
@@ -267,7 +283,8 @@ def rewrite(self, kind: str):
             changed = False
             for operand in expr.operands:
                 if isinstance(operand, Expr):
-                    new = operand.rewrite(kind=kind)
+                    new = operand.rewrite(kind=kind, cache=cache)
+                    cache[operand._name] = new
                     if new._name != operand._name:
                         changed = True
                 else:
@@ -275,13 +292,37 @@ def rewrite(self, kind: str):
                 new_operands.append(new)
 
             if changed:
-                expr = type(expr)(*new_operands)
+                expr = type(expr)(*new_operands, _branch_id=expr._branch_id)
                 continue
            else:
                 break
 
         return expr
 
+    def _reuse_up(self, parent):
+        return
+
+    def _reuse_down(self):
+        if not self.dependencies():
+            return
+        return self._bubble_branch_id_down()
+
+    def _bubble_branch_id_down(self):
+        b_id = self._branch_id
+        if b_id.branch_id <= 0:
+            return
+        if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()):
+            ops = [
+                op._substitute_branch_id(b_id) if isinstance(op, Expr) else op
+                for op in self.operands
+            ]
+            return type(self)(*ops)
+
+    def _substitute_branch_id(self, branch_id):
+        if self._branch_id.branch_id != 0:
+            return self
+        return type(self)(*self.operands, branch_id)
+
     def simplify_once(self, dependents: defaultdict, simplified: dict):
         """Simplify an expression
@@ -346,7 +387,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
                 new_operands.append(new)
 
         if changed:
-            expr = type(expr)(*new_operands)
+            expr = type(expr)(*new_operands, _branch_id=expr._branch_id)
 
         break
@@ -391,7 +432,7 @@ def lower_once(self):
                 new_operands.append(new)
 
         if changed:
-            out = type(out)(*new_operands)
+            out = type(out)(*new_operands, _branch_id=out._branch_id)
 
         return out
@@ -427,7 +468,9 @@ def _lower(self):
     @functools.cached_property
     def _name(self):
         return (
-            funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
+            funcname(type(self)).lower()
+            + "-"
+            + _tokenize_deterministic(*self.operands, self._branch_id)
         )
 
     @property
@@ -580,7 +623,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr:
             else:
                 new_operands.append(operand)
         if changed:
-            return type(self)(*new_operands)
+            return type(self)(*new_operands, _branch_id=self._branch_id)
         return self
 
     def _node_label_args(self):
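For orientation, here is a minimal standalone sketch of the pattern these `_core.py` hunks establish (toy class, not the `dask_expr` implementation; `dask.base.tokenize` stands in for `_tokenize_deterministic`): a trailing `BranchId` operand is peeled off in `__new__` and folded into the name token, which is what keeps two otherwise-identical branches of a diamond from collapsing into a single graph key.

```python
from typing import NamedTuple

from dask.base import tokenize  # deterministic hashing shipped with dask


class BranchId(NamedTuple):
    branch_id: int


class ToyExpr:
    def __new__(cls, *args, _branch_id=None):
        operands = list(args)
        # A BranchId passed as the last positional operand wins only when
        # no explicit keyword was given (mirrors Expr.__new__ above).
        if _branch_id is None and operands and isinstance(operands[-1], BranchId):
            _branch_id = operands.pop(-1)
        elif _branch_id is None:
            _branch_id = BranchId(0)
        inst = object.__new__(cls)
        inst.operands = operands
        inst._branch_id = _branch_id
        return inst

    @property
    def _name(self):
        # Tokenizing the branch id alongside the operands gives each branch
        # its own key, so the scheduler computes the branches separately.
        return "toyexpr-" + tokenize(*self.operands, self._branch_id)


a = ToyExpr("frame")
b = ToyExpr("frame", BranchId(1))
assert a._name != b._name  # same operands, distinct branches
```

Accepting the id as an optional trailing operand is what lets every existing `type(expr)(*operands)` reconstruction call site keep working unchanged.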
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 83682f19..35e510f2 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -53,6 +53,7 @@
 from tlz import merge_sorted, partition, unique
 
 from dask_expr import _core as core
+from dask_expr._core import BranchId
 from dask_expr._util import (
     _calc_maybe_new_divisions,
     _convert_to_list,
@@ -502,7 +503,7 @@ def _name(self):
             head = funcname(self.operation)
         else:
             head = funcname(type(self)).lower()
-        return head + "-" + _tokenize_deterministic(*self.operands)
+        return head + "-" + _tokenize_deterministic(*self.operands, self._branch_id)
 
     def _blockwise_arg(self, arg, i):
         """Return a Blockwise-task argument"""
@@ -2728,8 +2729,11 @@ class _DelayedExpr(Expr):
     # TODO
     _parameters = ["obj"]
 
-    def __init__(self, obj):
+    def __init__(self, obj, _branch_id=None):
         self.obj = obj
+        if _branch_id is None:
+            _branch_id = BranchId(0)
+        self._branch_id = _branch_id
         self.operands = [obj]
 
     def __str__(self):
@@ -2758,18 +2762,29 @@ def normalize_expression(expr):
     return expr._name
 
 
-def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr:
+def optimize_until(
+    expr: Expr, stage: core.OptimizerStage, common_subplan_elimination: bool = False
+) -> Expr:
     result = expr
     if stage == "logical":
         return result
 
-    # Simplify
-    expr = result.simplify()
+    while True:
+        if not common_subplan_elimination:
+            out = result.rewrite("reuse", cache={})
+        else:
+            out = result
+        out = out.simplify()
+        if out._name == result._name or common_subplan_elimination:
+            break
+        result = out
+
+    expr = out
     if stage == "simplified-logical":
         return expr
 
     # Manipulate Expression to make it more efficient
-    expr = expr.rewrite(kind="tune")
+    expr = expr.rewrite(kind="tune", cache={})
     if stage == "tuned-logical":
         return expr
@@ -2791,7 +2806,9 @@ def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr:
     raise ValueError(f"Stage {stage!r} not supported.")
 
 
-def optimize(expr: Expr, fuse: bool = True) -> Expr:
+def optimize(
+    expr: Expr, fuse: bool = True, common_subplan_elimination: bool = False
+) -> Expr:
     """High level query optimization
 
     This leverages three optimization passes:
@@ -2805,6 +2822,10 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr:
         Input expression to optimize
     fuse:
        whether or not to turn on blockwise fusion
+    common_subplan_elimination : bool, default False
+        Whether to reuse common subplans found in the graph, as in self-joins
+        and similar operations that require all data to be held in memory at
+        some point. Only set this to True if your dataset fits into memory.
 
     See Also
     --------
@@ -2813,7 +2834,7 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr:
     """
     stage: core.OptimizerStage = "fused" if fuse else "simplified-physical"
 
-    return optimize_until(expr, stage)
+    return optimize_until(expr, stage, common_subplan_elimination)
 
 
 def is_broadcastable(dfs, s):
@@ -3462,7 +3483,7 @@ def __str__(self):
 
     @functools.cached_property
     def _name(self):
-        return f"{str(self)}-{_tokenize_deterministic(self.exprs)}"
+        return f"{str(self)}-{_tokenize_deterministic(self.exprs, self._branch_id)}"
 
     def _divisions(self):
         return self.exprs[0]._divisions()
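A hedged usage sketch of the new keyword (the expression-level `optimize` signature is taken from the hunks above; whether the collection-level `.optimize()` forwards it is not shown in this diff):

```python
import pandas as pd

from dask_expr import from_pandas
from dask_expr._expr import optimize

df = from_pandas(pd.DataFrame({"x": range(100)}), npartitions=10)
plan = (df + df.sum()).expr  # a diamond: the input feeds both operands

# Default path: the "reuse" rewrite assigns distinct branch ids below the
# reduction, duplicating the shared IO subplan so that neither branch has
# to pin the whole input in memory.
split = optimize(plan, fuse=False)

# Opt-in path: skip the reuse rewrite and keep the common subplan shared;
# only sensible when the deduplicated intermediate fits into memory.
shared = optimize(plan, fuse=False, common_subplan_elimination=True)
```

Note the loop in `optimize_until`: the reuse rewrite and `simplify` are interleaved until the expression reaches a fixed point, since each pass can expose new opportunities for the other.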
diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 27f75d19..286deafa 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -26,6 +26,7 @@
 from dask.utils import M, apply, funcname
 
 from dask_expr._concat import Concat
+from dask_expr._core import BranchId
 from dask_expr._expr import (
     Blockwise,
     Expr,
@@ -300,7 +301,7 @@ def _name(self):
             name = funcname(self.combine.__self__).lower() + "-tree"
         else:
             name = funcname(self.combine)
-        return name + "-" + _tokenize_deterministic(*self.operands)
+        return name + "-" + _tokenize_deterministic(*self.operands, self._branch_id)
 
     def __dask_postcompute__(self):
         return toolz.first, ()
@@ -507,6 +508,48 @@ def _lower(self):
             ignore_index=getattr(self, "ignore_index", True),
         )
 
+    def _reuse_up(self, parent):
+        return
+
+    def _substitute_branch_id(self, branch_id):
+        return self
+
+    def _reuse_down(self):
+        if self._branch_id.branch_id != 0:
+            return
+
+        from dask_expr.io import IO
+
+        seen = set()
+        stack = self.dependencies()
+        counter, found_consumer = 1, False
+
+        while stack:
+            node = stack.pop()
+
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, IO):
+                found_consumer = True
+                continue
+
+            if isinstance(node, ApplyConcatApply):
+                counter += 1
+                continue
+
+            stack.extend(node.dependencies())
+
+        if not found_consumer:
+            return
+        b_id = BranchId(counter)
+        result = type(self)(*self.operands, b_id)
+        out = result._bubble_branch_id_down()
+        if out is None:
+            return result
+        return type(out)(*out.operands, _branch_id=b_id)
+
 
 class Unique(ApplyConcatApply):
     _parameters = ["frame", "split_every", "split_out", "shuffle_method"]
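Restated in isolation, the traversal in `ApplyConcatApply._reuse_down` above amounts to the following (a toy over plain callables rather than `Expr` objects; all names here are illustrative):

```python
def count_branches(node, is_io, is_aca, deps):
    """Walk the dependencies of ``node``, stopping at IO consumers and at
    nested reductions, and report how many branch ids are needed."""
    seen, stack = set(), list(deps(node))
    counter, found_consumer = 1, False
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue
        seen.add(id(n))
        if is_io(n):
            found_consumer = True  # reached a data source; stop descending
            continue
        if is_aca(n):
            counter += 1  # nested reduction: reserve another branch id
            continue  # its own _reuse_down pass handles everything below it
        stack.extend(deps(n))
    # Only split branches when the subplan actually bottoms out in IO.
    return counter if found_consumer else None


# Tiny demo: a reduction over an IO node two levels down needs one id.
deps_map = {"aca": ["proj"], "proj": ["io"], "io": []}
assert (
    count_branches(
        "aca",
        is_io=lambda n: n == "io",
        is_aca=lambda n: False,  # no nested reductions in this toy plan
        deps=lambda n: deps_map[n],
    )
    == 1
)
```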
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 1b6b34fe..e06682e6 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -48,7 +48,9 @@ def _divisions(self):
     @functools.cached_property
     def _name(self):
         return (
-            self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands)
+            self.operand("name_prefix")
+            + "-"
+            + _tokenize_deterministic(*self.operands, self._branch_id)
         )
 
     def _layer(self):
@@ -103,7 +105,7 @@ def _name(self):
         return (
             funcname(type(self.operand("_expr"))).lower()
             + "-fused-"
-            + _tokenize_deterministic(*self.operands)
+            + _tokenize_deterministic(*self.operands, self._expr._branch_id)
         )
 
     @functools.cached_property
@@ -173,10 +175,14 @@ def _name(self):
             return (
                 funcname(self.func).lower()
                 + "-"
-                + _tokenize_deterministic(*self.operands)
+                + _tokenize_deterministic(*self.operands, self._branch_id)
             )
         else:
-            return self.label + "-" + _tokenize_deterministic(*self.operands)
+            return (
+                self.label
+                + "-"
+                + _tokenize_deterministic(*self.operands, self._branch_id)
+            )
 
     @functools.cached_property
     def _meta(self):
@@ -448,7 +454,11 @@ class FromPandasDivisions(FromPandas):
 
     @functools.cached_property
     def _name(self):
-        return "from_pd_divs" + "-" + _tokenize_deterministic(*self.operands)
+        return (
+            "from_pd_divs"
+            + "-"
+            + _tokenize_deterministic(*self.operands, self._branch_id)
+        )
 
     @property
     def _divisions_and_locations(self):
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 8cd9e8c3..a1905331 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -501,7 +501,7 @@ def _name(self):
         return (
             funcname(type(self)).lower()
             + "-"
-            + _tokenize_deterministic(self.checksum, *self.operands)
+            + _tokenize_deterministic(self.checksum, *self.operands, self._branch_id)
         )
 
     @property
diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py
index 1f24bfda..b04a93af 100644
--- a/dask_expr/tests/_util.py
+++ b/dask_expr/tests/_util.py
@@ -5,6 +5,8 @@
 from dask import config
 from dask.dataframe.utils import assert_eq as dd_assert_eq
 
+from dask_expr.io import IO
+
 
 def _backend_name() -> str:
     return config.get("dataframe.backend", "pandas")
@@ -39,3 +41,12 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs):
 
     # Use `dask.dataframe.assert_eq`
     return dd_assert_eq(a, b, *args, **kwargs)
+
+
+def _check_consumer_node(expr, expected, consumer_node=IO, branch_id_counter=None):
+    if branch_id_counter is None:
+        branch_id_counter = expected
+    expr = expr.optimize(fuse=False)
+    io_nodes = list(expr.find_operations(consumer_node))
+    assert len(io_nodes) == expected
+    assert len({node._branch_id.branch_id for node in io_nodes}) == branch_id_counter
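For orientation, a hedged example of how the new helper reads at a call site (mirroring the counts asserted in `test_reuse.py` below):

```python
import pandas as pd

from dask_expr import from_pandas
from dask_expr.tests._util import _check_consumer_node

df = from_pandas(pd.DataFrame({"x": range(100)}), npartitions=10)

# The reduction creates one branch per consumer: after optimization the
# graph should contain the IO node twice, with two distinct branch ids.
_check_consumer_node(df + df.sum(), 2)
```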
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index c113432a..82178660 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -507,7 +507,7 @@ def test_diff(pdf, df, axis, periods):
     if axis in ("columns", 1):
         assert actual._name == actual.simplify()._name
     else:
-        assert actual.simplify()._name == expected.simplify()._name
+        assert actual.optimize()._name == expected.optimize()._name
 
 
 @pytest.mark.parametrize(
@@ -942,7 +942,7 @@ def test_repr(df):
     s = (df["x"] + 1).sum(skipna=False).expr
     assert '["x"]' in str(s) or "['x']" in str(s)
     assert "+ 1" in str(s)
-    assert "sum(skipna=False)" in str(s)
+    assert "sum(skipna=False" in str(s)
 
 
 @xfail_gpu("combine_first not supported by cudf")
@@ -1163,8 +1163,8 @@ def test_tail_repartition(df):
 
 def test_projection_stacking(df):
     result = df[["x", "y"]]["x"]
-    optimized = result.simplify()
-    expected = df["x"].simplify()
+    optimized = result.optimize()
+    expected = df["x"].optimize()
     assert optimized._name == expected._name
 
 
@@ -1885,8 +1885,8 @@ def test_assign_simplify(pdf):
     df = from_pandas(pdf)
     df2 = from_pandas(pdf)
     df["new"] = df.x > 1
-    result = df[["x", "new"]].simplify()
-    expected = df2[["x"]].assign(new=df2.x > 1).simplify()
+    result = df[["x", "new"]].optimize()
+    expected = df2[["x"]].assign(new=df2.x > 1).optimize()
     assert result._name == expected._name
 
     pdf["new"] = pdf.x > 1
@@ -1897,8 +1897,8 @@ def test_assign_simplify_new_column_not_needed(pdf):
     df = from_pandas(pdf)
     df2 = from_pandas(pdf)
     df["new"] = df.x > 1
-    result = df[["x"]].simplify()
-    expected = df2[["x"]].simplify()
+    result = df[["x"]].optimize()
+    expected = df2[["x"]].optimize()
     assert result._name == expected._name
 
     pdf["new"] = pdf.x > 1
@@ -1909,8 +1909,8 @@ def test_assign_simplify_series(pdf):
     df = from_pandas(pdf)
     df2 = from_pandas(pdf)
     df["new"] = df.x > 1
-    result = df.new.simplify()
-    expected = df2[[]].assign(new=df2.x > 1).new.simplify()
+    result = df.new.optimize()
+    expected = df2[[]].assign(new=df2.x > 1).new.optimize()
     assert result._name == expected._name
 
 
@@ -1928,7 +1928,16 @@ def test_assign_squash_together(df, pdf):
     df["a"] = 1
     df["b"] = 2
     result = df.simplify()
-    assert len([x for x in list(result.expr.walk()) if isinstance(x, expr.Assign)]) == 1
+    assert (
+        len(
+            [
+                x
+                for x in list(df.optimize(fuse=False).expr.walk())
+                if isinstance(x, expr.Assign)
+            ]
+        )
+        == 1
+    )
     pdf["a"] = 1
     pdf["b"] = 2
     assert_eq(df, pdf)
@@ -1973,10 +1982,10 @@ def test_astype_categories(df):
     assert_eq(result.y._meta.cat.categories, pd.Index([UNKNOWN_CATEGORIES]))
 
 
-def test_drop_simplify(df):
+def test_drop_optimize(df):
     q = df.drop(columns=["x"])[["y"]]
-    result = q.simplify()
-    expected = df[["y"]].simplify()
+    result = q.optimize()
+    expected = df[["y"]].optimize()
     assert result._name == expected._name
 
 
@@ -2064,6 +2073,7 @@ def test_filter_pushdown_unavailable(df):
     result = df[df.x > 5] + df.x.sum()
     result = result[["x"]]
     expected = df[["x"]][df.x > 5] + df.x.sum()
+    assert result.optimize()._name == expected.optimize()._name
     assert result.simplify()._name == expected.simplify()._name
 
 
@@ -2076,6 +2086,7 @@ def test_filter_pushdown(df, pdf):
     df = df.rename_axis(index="hello")
     result = df[df.x > 5].simplify()
     assert result._name == expected._name
+    assert result.optimize()._name == expected.optimize()._name
 
     pdf["z"] = 1
     df = from_pandas(pdf, npartitions=10)
@@ -2084,6 +2095,7 @@ def test_filter_pushdown(df, pdf):
     df_opt = df[["x", "y"]]
     expected = df_opt[df_opt.x > 5].rename_axis(index="hello").simplify()
     assert result._name == expected._name
+    assert result.optimize()._name == expected.optimize()._name
 
 
 def test_shape(df, pdf):
@@ -2433,13 +2445,13 @@ def test_reset_index_filter_pushdown(df):
     result = q[q > 5]
     expected = df["x"]
     expected = expected[expected > 5].reset_index(drop=True)
-    assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
     q = df.x.reset_index()
     result = q[q.x > 5]
     expected = df["x"]
     expected = expected[expected > 5].reset_index()
-    assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
 
 def test_astype_filter_pushdown(df, pdf):
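A note on the mechanical `simplify()` to `optimize()` swaps above: the reuse rewrite only runs inside `optimize_until`, so branch ids are not assigned yet at `simplify()` time, and name comparisons are only stable once both sides have been through the full pipeline, e.g.:

```python
import pandas as pd

from dask_expr import from_pandas

df = from_pandas(pd.DataFrame({"x": range(10), "y": 1}), npartitions=2)
result = df[["x", "y"]]["x"]
# Both sides run the full pipeline, including the reuse rewrite, before
# their names are compared (mirrors test_projection_stacking above).
assert result.optimize()._name == df["x"].optimize()._name
```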
diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
new file mode 100644
index 00000000..c9d481a3
--- /dev/null
+++ b/dask_expr/tests/test_reuse.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import pytest
+
+from dask_expr import from_pandas
+from dask_expr.tests._util import _backend_library, _check_consumer_node, assert_eq
+
+# Set DataFrame backend for this module
+pd = _backend_library()
+
+
+@pytest.fixture
+def pdf():
+    pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 1, "c": 1})
+    pdf["y"] = pdf.x // 7  # Not unique; duplicates span different partitions
+    yield pdf
+
+
+@pytest.fixture
+def df(pdf):
+    yield from_pandas(pdf, npartitions=10)
+
+
+def test_reuse_everything_scalar_and_series(df, pdf):
+    df["new"] = 1
+    df["new2"] = df["x"] + 1
+    df["new3"] = df.x[df.x > 1] + df.x[df.x > 2]
+
+    pdf["new"] = 1
+    pdf["new2"] = pdf["x"] + 1
+    pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2]
+    assert_eq(df, pdf)
+    _check_consumer_node(df, 1)
+
+
+def test_dont_reuse_reducer(df, pdf):
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.y.sum()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    assert_eq(result, expected)
+    _check_consumer_node(result, 2)
+
+    result = df + df.sum()
+    expected = pdf + pdf.sum()
+    assert_eq(result, expected, check_names=False)  # pandas 2.2 bug
+    _check_consumer_node(result, 2)
+
+    result = df.replace(1, 5)
+    rhs_1 = result.x + result.y.sum()
+    rhs_2 = result.b + result.a.sum()
+    result["new"] = rhs_1
+    result["new2"] = rhs_2
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    expected["new2"] = expected.b + expected.a.sum()
+    assert_eq(result, expected)
+    _check_consumer_node(result, 2)
+
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.y.sum()
+    result["new2"] = result.b + result.a.sum()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    expected["new2"] = expected.b + expected.a.sum()
+    assert_eq(result, expected)
+    _check_consumer_node(result, 3)
+
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.sum().dropna().prod()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.sum().dropna().prod()
+    assert_eq(result, expected)
+    _check_consumer_node(result, 2)
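Beyond the counts asserted above, the assigned ids can be inspected directly on the optimized expression; a hedged sketch (the exact id values are an implementation detail of the counter in `_reuse_down`):

```python
import pandas as pd

from dask_expr import from_pandas
from dask_expr.io import IO

df = from_pandas(pd.DataFrame({"x": range(100), "y": 1}), npartitions=10)
result = (df + df.sum()).optimize(fuse=False)

# Each IO node reports which branch of the reduction diamond it feeds;
# expect two nodes carrying two distinct branch ids.
for node in result.expr.find_operations(IO):
    print(node._branch_id.branch_id)
```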
diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py
index 36277622..40592e54 100644
--- a/dask_expr/tests/test_shuffle.py
+++ b/dask_expr/tests/test_shuffle.py
@@ -137,7 +137,7 @@ def test_shuffle_column_projection(df):
 
 
 def test_shuffle_reductions(df):
-    assert df.shuffle("x").sum().simplify()._name == df.sum()._name
+    assert df.shuffle("x").sum().optimize()._name == df.sum().optimize()._name
 
 
 @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection")
@@ -264,7 +264,7 @@ def test_set_index_repartition(df, pdf):
     assert_eq(result, pdf.set_index("x"))
 
 
-def test_set_index_simplify(df, pdf):
+def test_set_index_optimize(df, pdf):
     q = df.set_index("x")["y"].optimize(fuse=False)
     expected = df[["x", "y"]].set_index("x")["y"].optimize(fuse=False)
     assert q._name == expected._name
@@ -697,18 +697,21 @@ def test_shuffle_filter_pushdown(pdf, meth):
     result = result[result.x > 5.0]
     expected = getattr(df[df.x > 5.0], meth)("x")
     assert result.simplify()._name == expected._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = getattr(df, meth)("x")
     result = result[result.x > 5.0][["x", "y"]]
     expected = df[["x", "y"]]
     expected = getattr(expected[expected.x > 5.0], meth)("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = getattr(df, meth)("x")[["x", "y"]]
     result = result[result.x > 5.0]
     expected = df[["x", "y"]]
     expected = getattr(expected[expected.x > 5.0], meth)("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
 
 @pytest.mark.parametrize("meth", ["set_index", "sort_values"])
@@ -716,7 +719,7 @@ def test_sort_values_avoid_overeager_filter_pushdown(meth):
     pdf1 = pd.DataFrame({"a": [4, 2, 3], "b": [1, 2, 3]})
     df = from_pandas(pdf1, npartitions=2)
     df = getattr(df, meth)("a")
-    df = df[df.b > 2] + df.b.sum()
+    df = df[df.b > 2] + df[df.b > 1]
     result = df.simplify()
     assert isinstance(result.expr.left, Filter)
     assert isinstance(result.expr.left.frame, BaseSetIndexSortValues)
@@ -729,18 +732,21 @@ def test_set_index_filter_pushdown():
     result = result[result.y == 1]
     expected = df[df.y == 1].set_index("x")
     assert result.simplify()._name == expected._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = df.set_index("x")
     result = result[result.y == 1][["y"]]
     expected = df[["x", "y"]]
     expected = expected[expected.y == 1].set_index("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = df.set_index("x")[["y"]]
     result = result[result.y == 1]
     expected = df[["x", "y"]]
     expected = expected[expected.y == 1].set_index("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
 
 def test_shuffle_index_shuffle(df):