From c9cc99e42fe9b23dcd067ccd4d411f5a3d4f5342 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Thu, 1 Aug 2024 15:55:43 -0400 Subject: [PATCH] More Type Annotations (#177) * More type annotations. * Fix Python 3.8 issues. More typing. * More Python 3.8 fixes --- dbt_common/context.py | 4 ++-- dbt_common/contracts/config/base.py | 12 ++++++------ dbt_common/contracts/util.py | 7 ++++--- dbt_common/dataclass_schema.py | 2 +- dbt_common/events/contextvars.py | 2 +- dbt_common/exceptions/base.py | 26 +++++++++++++------------- 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/dbt_common/context.py b/dbt_common/context.py index 947d409a..a7722139 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -6,7 +6,7 @@ from dbt_common.record import Recorder -class CaseInsensitiveMapping(Mapping): +class CaseInsensitiveMapping(Mapping[str, str]): def __init__(self, env: Mapping[str, str]): self._env = {k.casefold(): (k, v) for k, v in env.items()} @@ -65,7 +65,7 @@ def env_secrets(self) -> List[str]: def reliably_get_invocation_var() -> ContextVar[InvocationContext]: - invocation_var: Optional[ContextVar] = next( + invocation_var: Optional[ContextVar[InvocationContext]] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/contracts/config/base.py b/dbt_common/contracts/config/base.py index 42acb1bf..df8afb24 100644 --- a/dbt_common/contracts/config/base.py +++ b/dbt_common/contracts/config/base.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, Field from itertools import chain -from typing import Callable, Dict, Any, List, TypeVar, Type +from typing import Any, Callable, Dict, Iterator, List, Type, TypeVar from dbt_common.contracts.config.metadata import Metadata from dbt_common.exceptions import CompilationError, DbtInternalError @@ -45,7 +45,7 @@ def __delitem__(self, key: str) -> None: else: del self._extra[key] - def _content_iterator(self, include_condition: Callable[[Field], bool]): + def _content_iterator(self, include_condition: Callable[[Field[Any]], bool]) -> Iterator[str]: seen = set() for fld, _ in self._get_fields(): seen.add(fld.name) @@ -57,7 +57,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]): seen.add(key) yield key - def __iter__(self): + def __iter__(self) -> Iterator[str]: yield from self._content_iterator(include_condition=lambda f: True) def __len__(self) -> int: @@ -76,7 +76,7 @@ def compare_key( elif key in unrendered and key not in other: return False else: - return unrendered[key] == other[key] + return bool(unrendered[key] == other[key]) @classmethod def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool: @@ -203,11 +203,11 @@ def metadata_key(cls) -> str: return "compare" @classmethod - def should_include(cls, fld: Field) -> bool: + def should_include(cls, fld: Field[Any]) -> bool: return cls.from_field(fld) == cls.Include -def _listify(value: Any) -> List: +def _listify(value: Any) -> List[Any]: if isinstance(value, list): return value[:] else: diff --git a/dbt_common/contracts/util.py b/dbt_common/contracts/util.py index 7ec02463..7bd26e3b 100644 --- a/dbt_common/contracts/util.py +++ b/dbt_common/contracts/util.py @@ -1,10 +1,11 @@ import dataclasses +from typing import Any # TODO: remove from dbt_common.contracts.util:: Replaceable + references class Replaceable: - def replace(self, **kwargs): - return dataclasses.replace(self, **kwargs) + def replace(self, **kwargs: Any): + return dataclasses.replace(self, **kwargs) # type: ignore class Mergeable(Replaceable): @@ -15,7 +16,7 @@ def merged(self, *args): replacements = {} cls = type(self) for arg in args: - for field in dataclasses.fields(cls): + for field in dataclasses.fields(cls): # type: ignore value = getattr(arg, field.name) if value is not None: replacements[field.name] = value diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 867d5a4c..4e003b13 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -92,7 +92,7 @@ def json_schema(cls): return json_schema @classmethod - def validate(cls, data): + def validate(cls, data: Any) -> None: json_schema = cls.json_schema() validator = jsonschema.Draft7Validator(json_schema) error = next(iter(validator.iter_errors(data)), None) diff --git a/dbt_common/events/contextvars.py b/dbt_common/events/contextvars.py index 1508cdcc..546675b4 100644 --- a/dbt_common/events/contextvars.py +++ b/dbt_common/events/contextvars.py @@ -22,7 +22,7 @@ def get_contextvars(prefix: str) -> Dict[str, Any]: return rv -def get_node_info(): +def get_node_info() -> Dict[str, Any]: cvars = get_contextvars(LOG_PREFIX) if "node_info" in cvars: return cvars["node_info"] diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index d966a28d..61bd97a9 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -23,7 +23,7 @@ class DbtBaseException(Exception): CODE = -32000 MESSAGE = "Server Error" - def data(self): + def data(self) -> Dict[str, Any]: # if overriding, make sure the result is json-serializable. return { "type": self.__class__.__name__, @@ -32,7 +32,7 @@ def data(self): class DbtInternalError(DbtBaseException): - def __init__(self, msg: str): + def __init__(self, msg: str) -> None: self.stack: List = [] self.msg = scrub_secrets(msg, env_secrets()) @@ -40,7 +40,7 @@ def __init__(self, msg: str): def type(self) -> str: return "Internal" - def process_stack(self): + def process_stack(self) -> List[str]: lines = [] stack = self.stack first = True @@ -81,7 +81,7 @@ def __init__(self, msg: str, node=None) -> None: self.node = node self.msg = scrub_secrets(msg, env_secrets()) - def add_node(self, node=None): + def add_node(self, node=None) -> None: if node is not None and node is not self.node: if self.node is not None: self.stack.append(self.node) @@ -91,7 +91,7 @@ def add_node(self, node=None): def type(self): return "Runtime" - def node_to_string(self, node: Any): + def node_to_string(self, node: Any) -> str: """Given a node-like object we attempt to create the best identifier we can.""" result = "" if hasattr(node, "resource_type"): @@ -103,7 +103,7 @@ def node_to_string(self, node: Any): return result.strip() if result != "" else "" - def process_stack(self): + def process_stack(self) -> List[str]: lines = [] stack = self.stack + [self.node] first = True @@ -122,7 +122,7 @@ def process_stack(self): return lines - def validator_error_message(self, exc: builtins.Exception): + def validator_error_message(self, exc: builtins.Exception) -> str: """Given a dbt.dataclass_schema.ValidationError return the relevant parts as a string. dbt.dataclass_schema.ValidationError is basically a jsonschema.ValidationError) @@ -132,7 +132,7 @@ def validator_error_message(self, exc: builtins.Exception): path = "[%s]" % "][".join(map(repr, exc.relative_path)) return f"at path {path}: {exc.message}" - def __str__(self, prefix: str = "! "): + def __str__(self, prefix: str = "! ") -> str: node_string = "" if self.node is not None: @@ -149,7 +149,7 @@ def __str__(self, prefix: str = "! "): return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]]) - def data(self): + def data(self) -> Dict[str, Any]: result = DbtBaseException.data(self) if self.node is None: return result @@ -236,7 +236,7 @@ class DbtDatabaseError(DbtRuntimeError): CODE = 10003 MESSAGE = "Database Error" - def process_stack(self): + def process_stack(self) -> List[str]: lines = [] if hasattr(self.node, "build_path") and self.node.build_path: @@ -250,7 +250,7 @@ def type(self): class UnexpectedNullError(DbtDatabaseError): - def __init__(self, field_name: str, source): + def __init__(self, field_name: str, source) -> None: self.field_name = field_name self.source = source msg = ( @@ -268,7 +268,7 @@ def __init__(self, cwd: str, cmd: List[str], msg: str = "Error running command") self.cmd = cmd_scrubbed self.args = (cwd, cmd_scrubbed, msg) - def __str__(self): + def __str__(self, prefix: str = "! ") -> str: if len(self.cmd) == 0: return f"{self.msg}: No arguments given" return f'{self.msg}: "{self.cmd[0]}"'