Skip to content

Commit

Permalink
POC for runtime type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
erlendvollset committed Dec 18, 2023
1 parent 46ada30 commit 60e2782
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cognite/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class CogniteException(Exception):
pass


class CogniteTypeError(CogniteException):
...


@dataclass
class GraphQLErrorSpec:
message: str
Expand Down
28 changes: 28 additions & 0 deletions cognite/client/utils/_runtime_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Callable, TypeVar

from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation

from cognite.client.exceptions import CogniteTypeError

T_Callable = TypeVar("T_Callable", bound=Callable)
T_Class = TypeVar("T_Class", bound=type)


def runtime_type_checked(f: T_Callable) -> T_Callable:
beartyped_f = beartype(f)

def f_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return beartyped_f(*args, **kwargs)
except BeartypeCallHintParamViolation as e:
raise CogniteTypeError(e.args[0])

return f_wrapped # type: ignore [return-value]


def runtime_type_checked_public_methods(c: T_Class) -> T_Class:
for name in dir(c):
if not name.startswith("_"):
setattr(c, name, runtime_type_checked(getattr(c, name)))
return c
103 changes: 103 additions & 0 deletions tests/tests_unit/test_utils/test_runtime_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import re
from typing import Union, overload, List

import pytest

from cognite.client.exceptions import CogniteTypeError
from cognite.client.utils._runtime_type_checking import runtime_type_checked_public_methods


class Foo:
...


class TestTypes:
@runtime_type_checked_public_methods
class Types:
def primitive(self, x: int) -> None:
...

def list(self, x: List[str]) -> None:
...

def custom_class(self, x: Foo) -> None:
...

def test_primitive(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
self.Types().primitive("1")

self.Types().primitive(1)

def test_list(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
"violates type hint typing.List[str], as str '1' not instance of list."
),
):
self.Types().list("1")

with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
"violates type hint typing.List[str], as list index 0 item int 1 not instance of str."
),
):
self.Types().list([1])

self.Types().list(["ok"])

def test_custom_type(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
"parameter x='1' violates type hint "
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
),
):
self.Types().custom_class("1")

self.Types().custom_class(Foo())


class TestOverloads:
@runtime_type_checked_public_methods
class WithOverload:
@overload
def foo(self, x: int, y: int) -> str:
...

@overload
def foo(self, x: str, y: str) -> str:
...

def foo(self, x: Union[int, str], y: Union[int, str]) -> str:
return f"{x}{y}"

def test_overloads(
self,
) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
"parameter y=1.0 violates type hint typing.Union[int, str], as float 1.0 not int or str."
),
):
self.WithOverload().foo(1, 1.0)

# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
self.WithOverload().foo(1, "1")

0 comments on commit 60e2782

Please sign in to comment.