Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Oct 31, 2023
1 parent 0204f81 commit 08b3344
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import copy
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
from dask.layers import DataFrameTreeReduction
from typing_extensions import TypeAlias

from dask_awkward.utils import LazyInputsDict

Expand All @@ -15,6 +16,9 @@
from awkward._nplikes.typetracer import TypeTracerReport


BackendT: TypeAlias = Union[Literal["cpu"], Literal["jax"], Literal["cuda"]]


class AwkwardBlockwiseLayer(Blockwise):
"""Just like upstream Blockwise, except we override pickling"""

Expand Down Expand Up @@ -55,7 +59,7 @@ class ImplementsMocking(Protocol):
def mock(self) -> AwkwardArray:
...

def mock_empty(self, backend: str) -> AwkwardArray:
def mock_empty(self, backend: BackendT) -> AwkwardArray:
...


Expand Down
13 changes: 12 additions & 1 deletion src/dask_awkward/lib/io/columnar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from awkward import Array as AwkwardArray
from awkward.forms import Form

from dask_awkward.layers.layers import ImplementsIOFunction, ImplementsNecessaryColumns
from dask_awkward.layers.layers import (
BackendT,
ImplementsIOFunction,
ImplementsNecessaryColumns,
)
from dask_awkward.lib.utils import (
METADATA_ATTRIBUTES,
FormStructure,
Expand Down Expand Up @@ -58,6 +62,13 @@ class ColumnProjectionMixin(ImplementsNecessaryColumns[FormStructure]):
def mock(self: S) -> AwkwardArray:
return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)

def mock_empty(self: S, backend: BackendT = "cpu") -> AwkwardArray:
return ak.to_backend(
self.form.length_zero_array(highlevel=False),
backend,
highlevel=True,
)

def prepare_for_projection(
self: S,
) -> tuple[AwkwardArray, TypeTracerReport, FormStructure]:
Expand Down

0 comments on commit 08b3344

Please sign in to comment.