Skip to content

Commit

Permalink
Added Result filter methods (#224)
Browse files Browse the repository at this point in the history
* Added Result filter methods

* PR update: Use a different error value in the tests

* PR updates

* Added pipeable functions for some methods that were missing them
* Added tests for `Result#map_error`
* Added tests for piping functions
  • Loading branch information
brendanmaguire authored Aug 29, 2024
1 parent 172bb4b commit 65cd3a9
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
80 changes: 80 additions & 0 deletions expression/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,38 @@ def is_ok(self) -> bool:
"""Return `True` if the result is an `Ok` value."""
return self.tag == "ok"

def filter(self, predicate: Callable[[_TSource], bool], default: _TError) -> Result[_TSource, _TError]:
"""Filter result.
Returns the input if the predicate evaluates to true, otherwise
returns the `default`
"""
match self:
case Result(tag="ok", ok=value) if predicate(value):
return self
case Result(tag="error"):
return self
case _:
return Error(default)

def filter_with(
self,
predicate: Callable[[_TSource], bool],
default: Callable[[_TSource], _TError],
) -> Result[_TSource, _TError]:
"""Filter result.
Returns the input if the predicate evaluates to true, otherwise
returns the `default` using the value as input
"""
match self:
case Result(tag="ok", ok=value) if predicate(value):
return self
case Result(tag="ok", ok=value):
return Error(default(value))
case Result():
return self

def dict(self) -> builtins.dict[str, _TSource | _TError | Literal["ok", "error"]]:
"""Return a json serializable representation of the result."""
match self:
Expand Down Expand Up @@ -352,6 +384,11 @@ def map2(
return x.map2(y, mapper)


@curry_flip(1)
def map_error(result: Result[_TSource, _TError], mapper: Callable[[_TError], _TResult]) -> Result[_TSource, _TResult]:
return result.map_error(mapper)


@curry_flip(1)
def bind(
result: Result[_TSource, _TError],
Expand All @@ -374,11 +411,46 @@ def is_error(result: Result[_TSource, _TError]) -> TypeGuard[Result[_TSource, _T
return result.is_error()


@curry_flip(1)
def filter(
result: Result[_TSource, _TError],
predicate: Callable[[_TSource], bool],
default: _TError,
) -> Result[_TSource, _TError]:
return result.filter(predicate, default)


@curry_flip(1)
def filter_with(
result: Result[_TSource, _TError],
predicate: Callable[[_TSource], bool],
default: Callable[[_TSource], _TError],
) -> Result[_TSource, _TError]:
return result.filter_with(predicate, default)


def swap(result: Result[_TSource, _TError]) -> Result[_TError, _TSource]:
"""Swaps the value in the result so an Ok becomes an Error and an Error becomes an Ok."""
return result.swap()


@curry_flip(1)
def or_else(result: Result[_TSource, _TError], other: Result[_TSource, _TError]) -> Result[_TSource, _TError]:
return result.or_else(other)


@curry_flip(1)
def or_else_with(
result: Result[_TSource, _TError],
other: Callable[[_TError], Result[_TSource, _TError]],
) -> Result[_TSource, _TError]:
return result.or_else_with(other)


def merge(result: Result[_TSource, _TSource]) -> _TSource:
return result.merge()


def to_option(result: Result[_TSource, Any]) -> Option[_TSource]:
from expression.core.option import Nothing, Some

Expand Down Expand Up @@ -406,9 +478,17 @@ def of_option_with(value: Option[_TSource], error: Callable[[], _TError]) -> Res
"map",
"bind",
"dict",
"filter",
"filter_with",
"is_ok",
"is_error",
"map2",
"map_error",
"merge",
"to_option",
"of_option",
"of_option_with",
"or_else",
"or_else_with",
"swap",
]
68 changes: 68 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def test_result_error_chained_map(msg: str, y: int):
case _:
assert False

@given(st.text())
def test_map_error(msg: str):
assert Error(msg).map_error(lambda x: f"more {x}") == Error("more " + msg)

@given(st.text())
def test_map_error_piped(msg: str):
assert Error(msg).pipe(result.map_error(lambda x: f"more {x}")) == Error(f"more {msg}")


@given(st.integers(), st.integers()) # type: ignore
def test_result_bind_piped(x: int, y: int):
Expand Down Expand Up @@ -362,6 +370,54 @@ def test_pipeline_error():
assert hn(42) == error


def test_filter_ok_passing_predicate():
xs: Result[int, str] = Ok(42)
ys = xs.filter(lambda x: x > 10, "error")

assert ys == xs


def test_filter_ok_failing_predicate():
xs: Result[int, str] = Ok(5)
ys = xs.filter(lambda x: x > 10, "error")

assert ys == Error("error")


def test_filter_error():
error = Error("original error")
ys = error.filter(lambda x: x > 10, "error")

assert ys == error

def test_filter_piped():
assert Ok(42).pipe(result.filter(lambda x: x > 10, "error")) == Ok(42)


def test_filter_with_ok_passing_predicate():
xs: Result[int, str] = Ok(42)
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == xs


def test_filter_with_ok_failing_predicate():
xs: Result[int, str] = Ok(5)
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == Error("error 5")


def test_filter_with_error():
error = Error("original error")
ys = error.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == error

def test_filter_with_piped():
assert Ok(42).pipe(result.filter_with(lambda x: x > 10, lambda value: f"error {value}")) == Ok(42)


class MyError(BaseModel):
message: str

Expand Down Expand Up @@ -525,6 +581,8 @@ def test_result_swap_with_error():
xs = result.swap(error)
assert xs == Ok(1)

def test_swap_piped():
assert Ok(42).pipe(result.swap) == Error(42)

def test_ok_or_else_ok():
xs: Result[int, str] = Ok(42)
Expand All @@ -549,6 +607,8 @@ def test_error_or_else_error():
ys = xs.or_else(Error("new error"))
assert ys == Error("new error")

def test_or_else_piped():
assert Ok(42).pipe(result.or_else(Ok(0))) == Ok(42)

def test_ok_or_else_with_ok():
xs: Result[str, str] = Ok("good")
Expand All @@ -574,6 +634,10 @@ def test_error_or_else_with_error():
assert ys == Error("new error from original error")


def test_or_else_with_piped():
assert Ok(42).pipe(result.or_else_with(lambda _: Ok(0))) == Ok(42)


def test_merge_ok():
assert Result.Ok(42).merge() == 42

Expand Down Expand Up @@ -601,3 +665,7 @@ class Child2(Parent):
def test_merge_subclasses():
xs: Result[Parent, Parent] = Result.Ok(Child1(x=42))
assert xs.merge() == Child1(x=42)


def test_merge_piped():
assert Ok(42).pipe(result.merge) == 42

0 comments on commit 65cd3a9

Please sign in to comment.