From 60e278257f2c1317bfceadde8e8f60345a382438 Mon Sep 17 00:00:00 2001 From: erlendvollset Date: Mon, 18 Dec 2023 17:16:04 +0100 Subject: [PATCH] POC for runtime type checking --- cognite/client/exceptions.py | 4 + .../client/utils/_runtime_type_checking.py | 28 +++++ .../test_utils/test_runtime_type_checking.py | 103 ++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 cognite/client/utils/_runtime_type_checking.py create mode 100644 tests/tests_unit/test_utils/test_runtime_type_checking.py diff --git a/cognite/client/exceptions.py b/cognite/client/exceptions.py index dc66ad7936..1a9b8f518b 100644 --- a/cognite/client/exceptions.py +++ b/cognite/client/exceptions.py @@ -15,6 +15,10 @@ class CogniteException(Exception): pass +class CogniteTypeError(CogniteException): + ... + + @dataclass class GraphQLErrorSpec: message: str diff --git a/cognite/client/utils/_runtime_type_checking.py b/cognite/client/utils/_runtime_type_checking.py new file mode 100644 index 0000000000..3614c774c4 --- /dev/null +++ b/cognite/client/utils/_runtime_type_checking.py @@ -0,0 +1,28 @@ +from typing import Any, Callable, TypeVar + +from beartype import beartype +from beartype.roar import BeartypeCallHintParamViolation + +from cognite.client.exceptions import CogniteTypeError + +T_Callable = TypeVar("T_Callable", bound=Callable) +T_Class = TypeVar("T_Class", bound=type) + + +def runtime_type_checked(f: T_Callable) -> T_Callable: + beartyped_f = beartype(f) + + def f_wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return beartyped_f(*args, **kwargs) + except BeartypeCallHintParamViolation as e: + raise CogniteTypeError(e.args[0]) + + return f_wrapped # type: ignore [return-value] + + +def runtime_type_checked_public_methods(c: T_Class) -> T_Class: + for name in dir(c): + if not name.startswith("_"): + setattr(c, name, runtime_type_checked(getattr(c, name))) + return c diff --git a/tests/tests_unit/test_utils/test_runtime_type_checking.py b/tests/tests_unit/test_utils/test_runtime_type_checking.py new file mode 100644 index 0000000000..3a79b3bc99 --- /dev/null +++ b/tests/tests_unit/test_utils/test_runtime_type_checking.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import re +from typing import Union, overload, List + +import pytest + +from cognite.client.exceptions import CogniteTypeError +from cognite.client.utils._runtime_type_checking import runtime_type_checked_public_methods + + +class Foo: + ... + + +class TestTypes: + @runtime_type_checked_public_methods + class Types: + def primitive(self, x: int) -> None: + ... + + def list(self, x: List[str]) -> None: + ... + + def custom_class(self, x: Foo) -> None: + ... + + def test_primitive(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.Types().primitive("1") + + self.Types().primitive(1) + + def test_list(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " + "violates type hint typing.List[str], as str '1' not instance of list." + ), + ): + self.Types().list("1") + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " + "violates type hint typing.List[str], as list index 0 item int 1 not instance of str." + ), + ): + self.Types().list([1]) + + self.Types().list(["ok"]) + + def test_custom_type(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " + "parameter x='1' violates type hint " + ", as str '1' not instance " + 'of ' + ), + ): + self.Types().custom_class("1") + + self.Types().custom_class(Foo()) + + +class TestOverloads: + @runtime_type_checked_public_methods + class WithOverload: + @overload + def foo(self, x: int, y: int) -> str: + ... + + @overload + def foo(self, x: str, y: str) -> str: + ... + + def foo(self, x: Union[int, str], y: Union[int, str]) -> str: + return f"{x}{y}" + + def test_overloads( + self, + ) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " + "parameter y=1.0 violates type hint typing.Union[int, str], as float 1.0 not int or str." + ), + ): + self.WithOverload().foo(1, 1.0) + + # Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet + self.WithOverload().foo(1, "1")