diff --git a/expression/core/result.py b/expression/core/result.py index e352d29..7690445 100644 --- a/expression/core/result.py +++ b/expression/core/result.py @@ -150,6 +150,38 @@ def is_ok(self) -> bool: """Return `True` if the result is an `Ok` value.""" return self.tag == "ok" + def filter(self, predicate: Callable[[_TSource], bool], default: _TError) -> Result[_TSource, _TError]: + """Filter result. + + Returns the input if the predicate evaluates to true, otherwise + returns the `default` + """ + match self: + case Result(tag="ok", ok=value) if predicate(value): + return self + case Result(tag="error"): + return self + case _: + return Error(default) + + def filter_with( + self, + predicate: Callable[[_TSource], bool], + default: Callable[[_TSource], _TError], + ) -> Result[_TSource, _TError]: + """Filter result. + + Returns the input if the predicate evaluates to true, otherwise + returns the `default` using the value as input + """ + match self: + case Result(tag="ok", ok=value) if predicate(value): + return self + case Result(tag="ok", ok=value): + return Error(default(value)) + case Result(): + return self + def dict(self) -> builtins.dict[str, _TSource | _TError | Literal["ok", "error"]]: """Return a json serializable representation of the result.""" match self: @@ -352,6 +384,11 @@ def map2( return x.map2(y, mapper) +@curry_flip(1) +def map_error(result: Result[_TSource, _TError], mapper: Callable[[_TError], _TResult]) -> Result[_TSource, _TResult]: + return result.map_error(mapper) + + @curry_flip(1) def bind( result: Result[_TSource, _TError], @@ -374,11 +411,46 @@ def is_error(result: Result[_TSource, _TError]) -> TypeGuard[Result[_TSource, _T return result.is_error() +@curry_flip(1) +def filter( + result: Result[_TSource, _TError], + predicate: Callable[[_TSource], bool], + default: _TError, +) -> Result[_TSource, _TError]: + return result.filter(predicate, default) + + +@curry_flip(1) +def filter_with( + result: Result[_TSource, _TError], + predicate: Callable[[_TSource], bool], + default: Callable[[_TSource], _TError], +) -> Result[_TSource, _TError]: + return result.filter_with(predicate, default) + + def swap(result: Result[_TSource, _TError]) -> Result[_TError, _TSource]: """Swaps the value in the result so an Ok becomes an Error and an Error becomes an Ok.""" return result.swap() +@curry_flip(1) +def or_else(result: Result[_TSource, _TError], other: Result[_TSource, _TError]) -> Result[_TSource, _TError]: + return result.or_else(other) + + +@curry_flip(1) +def or_else_with( + result: Result[_TSource, _TError], + other: Callable[[_TError], Result[_TSource, _TError]], +) -> Result[_TSource, _TError]: + return result.or_else_with(other) + + +def merge(result: Result[_TSource, _TSource]) -> _TSource: + return result.merge() + + def to_option(result: Result[_TSource, Any]) -> Option[_TSource]: from expression.core.option import Nothing, Some @@ -406,9 +478,17 @@ def of_option_with(value: Option[_TSource], error: Callable[[], _TError]) -> Res "map", "bind", "dict", + "filter", + "filter_with", "is_ok", "is_error", + "map2", + "map_error", + "merge", "to_option", "of_option", "of_option_with", + "or_else", + "or_else_with", + "swap", ] diff --git a/tests/test_result.py b/tests/test_result.py index 909eeba..5a0ea2a 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -194,6 +194,14 @@ def test_result_error_chained_map(msg: str, y: int): case _: assert False +@given(st.text()) +def test_map_error(msg: str): + assert Error(msg).map_error(lambda x: f"more {x}") == Error("more " + msg) + +@given(st.text()) +def test_map_error_piped(msg: str): + assert Error(msg).pipe(result.map_error(lambda x: f"more {x}")) == Error(f"more {msg}") + @given(st.integers(), st.integers()) # type: ignore def test_result_bind_piped(x: int, y: int): @@ -362,6 +370,54 @@ def test_pipeline_error(): assert hn(42) == error +def test_filter_ok_passing_predicate(): + xs: Result[int, str] = Ok(42) + ys = xs.filter(lambda x: x > 10, "error") + + assert ys == xs + + +def test_filter_ok_failing_predicate(): + xs: Result[int, str] = Ok(5) + ys = xs.filter(lambda x: x > 10, "error") + + assert ys == Error("error") + + +def test_filter_error(): + error = Error("original error") + ys = error.filter(lambda x: x > 10, "error") + + assert ys == error + +def test_filter_piped(): + assert Ok(42).pipe(result.filter(lambda x: x > 10, "error")) == Ok(42) + + +def test_filter_with_ok_passing_predicate(): + xs: Result[int, str] = Ok(42) + ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}") + + assert ys == xs + + +def test_filter_with_ok_failing_predicate(): + xs: Result[int, str] = Ok(5) + ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}") + + assert ys == Error("error 5") + + +def test_filter_with_error(): + error = Error("original error") + ys = error.filter_with(lambda x: x > 10, lambda value: f"error {value}") + + assert ys == error + +def test_filter_with_piped(): + assert Ok(42).pipe(result.filter_with(lambda x: x > 10, lambda value: f"error {value}")) == Ok(42) + + class MyError(BaseModel): message: str @@ -525,6 +581,8 @@ def test_result_swap_with_error(): xs = result.swap(error) assert xs == Ok(1) +def test_swap_piped(): + assert Ok(42).pipe(result.swap) == Error(42) def test_ok_or_else_ok(): xs: Result[int, str] = Ok(42) @@ -549,6 +607,8 @@ def test_error_or_else_error(): ys = xs.or_else(Error("new error")) assert ys == Error("new error") +def test_or_else_piped(): + assert Ok(42).pipe(result.or_else(Ok(0))) == Ok(42) def test_ok_or_else_with_ok(): xs: Result[str, str] = Ok("good") @@ -574,6 +634,10 @@ def test_error_or_else_with_error(): assert ys == Error("new error from original error") +def test_or_else_with_piped(): + assert Ok(42).pipe(result.or_else_with(lambda _: Ok(0))) == Ok(42) + + def test_merge_ok(): assert Result.Ok(42).merge() == 42 @@ -601,3 +665,7 @@ class Child2(Parent): def test_merge_subclasses(): xs: Result[Parent, Parent] = Result.Ok(Child1(x=42)) assert xs.merge() == Child1(x=42) + + +def test_merge_piped(): + assert Ok(42).pipe(result.merge) == 42