Skip to content

Commit

Permalink
Add par_map methods to collections
Browse files Browse the repository at this point in the history
* `par_map` facilitates running an async function across all items in the collection
* Resolves #213
  • Loading branch information
brendanmaguire committed Aug 30, 2024
1 parent 65cd3a9 commit 49ebb17
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 4 deletions.
30 changes: 29 additions & 1 deletion expression/collections/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from __future__ import annotations

import array
import asyncio
import builtins
import functools
from collections.abc import Callable, Iterable, Iterator, MutableSequence
from collections.abc import Awaitable, Callable, Iterable, Iterator, MutableSequence
from enum import Enum
from typing import Any, TypeVar, cast

Expand Down Expand Up @@ -185,9 +186,36 @@ def __init__(
self.typecode = typecode

def map(self, mapping: Callable[[_TSource], _TResult]) -> TypedArray[_TResult]:
    """Map array.

    Applies ``mapping`` to each element of the array and collects the
    results into a new array.

    Args:
        mapping: A function to transform items from the input array.

    Returns:
        A new array holding the transformed items.
    """
    return TypedArray(builtins.map(mapping, self.value))

async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> TypedArray[_TResult]:
    """Map array asynchronously.

    Calls the asynchronous ``mapping`` on every element of the array,
    awaits all of the resulting awaitables together, and builds a new
    array from their results.

    Args:
        mapping: An asynchronous function to transform items from the
            input array.

    Returns:
        A new array holding the awaited results, in the same order as
        the input items.
    """
    pending = [mapping(item) for item in self]
    results = await asyncio.gather(*pending)
    return TypedArray(results)

def choose(self, chooser: Callable[[_TSource], Option[_TResult]]) -> TypedArray[_TResult]:
"""Choose items from the list.
Expand Down
20 changes: 19 additions & 1 deletion expression/collections/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@

from __future__ import annotations

import asyncio
import builtins
import functools
import itertools
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from collections.abc import Awaitable, Callable, Collection, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args, overload

from typing_extensions import TypeVarTuple, Unpack
Expand Down Expand Up @@ -239,6 +240,23 @@ def map(self, mapping: Callable[[_TSource], _TResult]) -> Block[_TResult]:
"""
return Block((*builtins.map(mapping, self),))

async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> Block[_TResult]:
    """Map list asynchronously.

    Calls the asynchronous ``mapping`` on every element of the
    collection, awaits all of the resulting awaitables together, and
    builds a new collection from their results.

    Args:
        mapping: The asynchronous function to transform elements from
            the input list.

    Returns:
        The list of transformed elements, in the same order as the
        input elements.
    """
    pending = [mapping(item) for item in self]
    results = await asyncio.gather(*pending)
    return Block(results)

def starmap(self: Block[tuple[Unpack[_P]]], mapping: Callable[[Unpack[_P]], _TResult]) -> Block[_TResult]:
"""Starmap source sequence.
Expand Down
22 changes: 21 additions & 1 deletion expression/collections/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# - https://github.com/fsharp/fsharp/blob/master/src/fsharp/FSharp.Core/map.fs
from __future__ import annotations

from collections.abc import Callable, ItemsView, Iterable, Iterator, Mapping
import asyncio
from collections.abc import Awaitable, Callable, ItemsView, Iterable, Iterator, Mapping
from typing import Any, TypeVar, cast

from expression.core import Option, PipeMixin, SupportsLessThan, curry_flip, pipe
Expand Down Expand Up @@ -114,6 +115,25 @@ def map(self, mapping: Callable[[_Key, _Value], _Result]) -> Map[_Key, _Result]:
"""
return Map(maptree.map(mapping, self._tree))

async def par_map(self, mapping: Callable[[_Key, _Value], Awaitable[_Result]]) -> Map[_Key, _Result]:
    """Map the mapping asynchronously.

    Builds a new map whose values are the results of applying the
    given asynchronous function to each key/value pair of this map.
    All of the awaitables are awaited together, and each result is
    stored under the key of the pair it was computed from.

    Args:
        mapping: The asynchronous function to transform the key/value
            pairs. It is called with the key and the value of each
            element.

    Returns:
        The resulting map of keys and transformed values.
    """
    # Materialize the pairs once: they are traversed twice below (once
    # to start the awaitables, once to recover the keys), so this must
    # not depend on ``to_seq()`` returning a re-iterable object.
    items = list(self.to_seq())
    results = await asyncio.gather(*(mapping(key, value) for key, value in items))
    keys = [key for key, _ in items]
    # ``asyncio.gather`` returns results in the order the awaitables
    # were passed, so keys and results line up positionally.
    return Map.of_seq(zip(keys, results))

def partition(self, predicate: Callable[[_Key, _Value], bool]) -> tuple[Map[_Key, _Value], Map[_Key, _Value]]:
    """Partition the map in two using ``predicate``.

    Delegates to ``maptree.partition`` on the underlying tree and wraps
    each of the two resulting trees in a ``Map``.

    Args:
        predicate: A function taking a key and a value and returning a
            bool.

    Returns:
        A tuple of two maps, in the order produced by
        ``maptree.partition``.
    """
    first, second = maptree.partition(predicate, self._tree)
    return (Map(first), Map(second))
Expand Down
19 changes: 18 additions & 1 deletion expression/collections/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@

from __future__ import annotations

import asyncio
import builtins
import functools
import itertools
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Awaitable, Callable, Iterable, Iterator
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

from expression.core import (
Expand Down Expand Up @@ -175,6 +176,22 @@ def map(self, mapper: Callable[[_TSource], _TResult]) -> Seq[_TResult]:
"""
return Seq(pipe(self, map(mapper)))

async def par_map(self, mapper: Callable[[_TSource], Awaitable[_TResult]]) -> Seq[_TResult]:
    """Map sequence asynchronously.

    Calls the asynchronous ``mapper`` on every element of the
    sequence, awaits all of the resulting awaitables together, and
    builds a new sequence from their results.

    Args:
        mapper: An asynchronous function to transform items from the
            input sequence.

    Returns:
        The result sequence, in the same order as the input items.
    """
    pending = [mapper(item) for item in self]
    results = await asyncio.gather(*pending)
    return Seq(results)

@overload
def starmap(self: Seq[tuple[_T1, _T2]], mapping: Callable[[_T1, _T2], _TResult]) -> Seq[_TResult]: ...

Expand Down
18 changes: 18 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import functools
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -428,3 +429,20 @@ def test_array_monad_law_associativity_iterable(xs: list[int]):

m = array.of_seq(xs)
assert m.collect(f).collect(g) == m.collect(lambda x: f(x).collect(g))

@pytest.mark.asyncio
async def test_par_map():
    """TypedArray.par_map awaits the mapper for all items concurrently."""

    async def async_fn(i: int) -> int:
        await asyncio.sleep(0.1)
        return i * 2

    xs = TypedArray(range(1, 10))

    # Inside a coroutine, get_running_loop() is the preferred way to
    # reach the loop; its clock is used to time the call.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    ys = await xs.par_map(async_fn)
    end_time = loop.time()

    assert ys == TypedArray(i * 2 for i in range(1, 10))

    # Nine sequential 0.1s sleeps would take about 0.9s; finishing in
    # under 0.2s shows that the calls overlapped.
    time_taken = end_time - start_time
    assert time_taken < 0.2, "par_map took too long"
19 changes: 19 additions & 0 deletions tests/test_block.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import functools
from builtins import list as list
from collections.abc import Callable
Expand All @@ -7,6 +8,7 @@
from hypothesis import strategies as st
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
import pytest

from expression import Nothing, Option, Some, pipe
from expression.collections import Block, block
Expand Down Expand Up @@ -458,3 +460,20 @@ def test_serialize_block_works():
assert model_.annotated_type_empty == block.empty
assert model_.custom_type == Block(["a", "b", "c"])
assert model_.custom_type_empty == block.empty

@pytest.mark.asyncio
async def test_par_map():
    """Block.par_map awaits the mapper for all items concurrently."""

    async def async_fn(i: int) -> int:
        await asyncio.sleep(0.1)
        return i * 2

    xs = Block(range(1, 10))

    # Inside a coroutine, get_running_loop() is the preferred way to
    # reach the loop; its clock is used to time the call.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    ys = await xs.par_map(async_fn)
    end_time = loop.time()

    assert ys == Block(i * 2 for i in range(1, 10))

    # Nine sequential 0.1s sleeps would take about 0.9s; finishing in
    # under 0.2s shows that the calls overlapped.
    time_taken = end_time - start_time
    assert time_taken < 0.2, "par_map took too long"
20 changes: 20 additions & 0 deletions tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
from collections.abc import Callable, ItemsView, Iterable

import pytest
from hypothesis import given # type: ignore
from hypothesis import strategies as st

Expand Down Expand Up @@ -150,3 +152,21 @@ def test_expression_issue_105():
m = m.add("1", 1).add("2", 2).add("3", 3).add("4", 4)
m = m.change("2", lambda x: x)
m = m.change("3", lambda x: x)


@pytest.mark.asyncio
async def test_par_map():
    """Map.par_map awaits the mapper for all pairs concurrently."""

    async def async_fn(key: str, value: int) -> int:
        await asyncio.sleep(0.1)
        return int(key) * value

    xs = Map.of_seq((str(i), i) for i in range(1, 10))

    # Inside a coroutine, get_running_loop() is the preferred way to
    # reach the loop; its clock is used to time the call.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    ys = await xs.par_map(async_fn)
    end_time = loop.time()

    assert ys == Map.of_seq((str(i), i * i) for i in range(1, 10))

    # Nine sequential 0.1s sleeps would take about 0.9s; finishing in
    # under 0.2s shows that the calls overlapped.
    time_taken = end_time - start_time
    assert time_taken < 0.2, "par_map took too long"
19 changes: 19 additions & 0 deletions tests/test_seq.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import functools
from collections.abc import Callable, Iterable
from itertools import accumulate
Expand Down Expand Up @@ -382,3 +383,21 @@ def test_seq_monad_law_associativity_empty(value: int):
# Empty list
m = empty
assert list(m.collect(f).collect(g)) == list(m.collect(lambda x: f(x).collect(g)))


@pytest.mark.asyncio
async def test_par_map():
    """Seq.par_map awaits the mapper for all items concurrently."""

    async def async_fn(i: int) -> int:
        await asyncio.sleep(0.1)
        return i * 2

    xs = seq.of_iterable(range(1, 10))

    # Inside a coroutine, get_running_loop() is the preferred way to
    # reach the loop; its clock is used to time the call.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    ys = await xs.par_map(async_fn)
    end_time = loop.time()

    assert list(ys) == [i * 2 for i in range(1, 10)]

    # Nine sequential 0.1s sleeps would take about 0.9s; finishing in
    # under 0.2s shows that the calls overlapped.
    time_taken = end_time - start_time
    assert time_taken < 0.2, "par_map took too long"

0 comments on commit 49ebb17

Please sign in to comment.