-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
46ada30
commit 60e2782
Showing
3 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Any, Callable, TypeVar | ||
|
||
from beartype import beartype | ||
from beartype.roar import BeartypeCallHintParamViolation | ||
|
||
from cognite.client.exceptions import CogniteTypeError | ||
|
||
T_Callable = TypeVar("T_Callable", bound=Callable) | ||
T_Class = TypeVar("T_Class", bound=type) | ||
|
||
|
||
def runtime_type_checked(f: T_Callable) -> T_Callable: | ||
beartyped_f = beartype(f) | ||
|
||
def f_wrapped(*args: Any, **kwargs: Any) -> Any: | ||
try: | ||
return beartyped_f(*args, **kwargs) | ||
except BeartypeCallHintParamViolation as e: | ||
raise CogniteTypeError(e.args[0]) | ||
|
||
return f_wrapped # type: ignore [return-value] | ||
|
||
|
||
def runtime_type_checked_public_methods(c: T_Class) -> T_Class: | ||
for name in dir(c): | ||
if not name.startswith("_"): | ||
setattr(c, name, runtime_type_checked(getattr(c, name))) | ||
return c |
103 changes: 103 additions & 0 deletions
103
tests/tests_unit/test_utils/test_runtime_type_checking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from __future__ import annotations | ||
|
||
import re | ||
from typing import Union, overload, List | ||
|
||
import pytest | ||
|
||
from cognite.client.exceptions import CogniteTypeError | ||
from cognite.client.utils._runtime_type_checking import runtime_type_checked_public_methods | ||
|
||
|
||
class Foo: | ||
... | ||
|
||
|
||
class TestTypes: | ||
@runtime_type_checked_public_methods | ||
class Types: | ||
def primitive(self, x: int) -> None: | ||
... | ||
|
||
def list(self, x: List[str]) -> None: | ||
... | ||
|
||
def custom_class(self, x: Foo) -> None: | ||
... | ||
|
||
def test_primitive(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.Types().primitive("1") | ||
|
||
self.Types().primitive(1) | ||
|
||
def test_list(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " | ||
"violates type hint typing.List[str], as str '1' not instance of list." | ||
), | ||
): | ||
self.Types().list("1") | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " | ||
"violates type hint typing.List[str], as list index 0 item int 1 not instance of str." | ||
), | ||
): | ||
self.Types().list([1]) | ||
|
||
self.Types().list(["ok"]) | ||
|
||
def test_custom_type(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " | ||
"parameter x='1' violates type hint " | ||
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance " | ||
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">' | ||
), | ||
): | ||
self.Types().custom_class("1") | ||
|
||
self.Types().custom_class(Foo()) | ||
|
||
|
||
class TestOverloads: | ||
@runtime_type_checked_public_methods | ||
class WithOverload: | ||
@overload | ||
def foo(self, x: int, y: int) -> str: | ||
... | ||
|
||
@overload | ||
def foo(self, x: str, y: str) -> str: | ||
... | ||
|
||
def foo(self, x: Union[int, str], y: Union[int, str]) -> str: | ||
return f"{x}{y}" | ||
|
||
def test_overloads( | ||
self, | ||
) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " | ||
"parameter y=1.0 violates type hint typing.Union[int, str], as float 1.0 not int or str." | ||
), | ||
): | ||
self.WithOverload().foo(1, 1.0) | ||
|
||
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet | ||
self.WithOverload().foo(1, "1") |