Skip to content

Commit

Permalink
Merge pull request astropy#17546 from eerovaher/config-annotations
Browse files Browse the repository at this point in the history
Add more type annotations to `config`
  • Loading branch information
neutrinoceros authored Dec 18, 2024
2 parents 4e10fa0 + accdd0b commit 45ff837
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
23 changes: 15 additions & 8 deletions astropy/config/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
found at https://configobj.readthedocs.io .
"""

from __future__ import annotations

import contextlib
import importlib
import io
Expand All @@ -19,6 +21,7 @@
from inspect import getdoc
from pathlib import Path
from textwrap import TextWrapper
from typing import TYPE_CHECKING
from warnings import warn

from astropy.extern.configobj import configobj, validate
Expand All @@ -27,6 +30,10 @@

from .paths import get_config_dir_path

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Final

__all__ = (
"ConfigItem",
"ConfigNamespace",
Expand Down Expand Up @@ -84,7 +91,7 @@ class Conf(_config.ConfigNamespace):
conf = Conf()
"""

def __iter__(self):
def __iter__(self) -> Generator[str, None, None]:
for key, val in self.__class__.__dict__.items():
if isinstance(val, ConfigItem):
yield key
Expand All @@ -100,19 +107,19 @@ def __str__(self):
keys = __iter__
"""Iterate over configuration item names."""

def values(self):
def values(self) -> Generator[ConfigItem, None, None]:
"""Iterate over configuration item values."""
for val in self.__class__.__dict__.values():
if isinstance(val, ConfigItem):
yield val

def items(self):
def items(self) -> Generator[tuple[str, ConfigItem], None, None]:
"""Iterate over configuration item ``(name, value)`` pairs."""
for key, val in self.__class__.__dict__.items():
if isinstance(val, ConfigItem):
yield key, val

def help(self, name=None):
def help(self, name: str | None = None) -> None:
"""Print info about configuration items.
Parameters
Expand Down Expand Up @@ -185,7 +192,7 @@ def reload(self, attr=None):
for item in self.values():
item.reload()

def reset(self, attr=None):
def reset(self, attr: str | None = None) -> None:
"""
Reset a configuration item to its default.
Expand Down Expand Up @@ -426,13 +433,13 @@ def reload(self):
baseobj[self.name] = newobj[self.name]
return baseobj.get(self.name)

def __repr__(self):
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__}: name={self.name!r} value={self()!r} at"
f" 0x{id(self):x}>"
)

def __str__(self):
def __str__(self) -> str:
return "\n".join(
(
f"{self.__class__.__name__}: {self.name}",
Expand Down Expand Up @@ -529,7 +536,7 @@ def _validate_val(self, val):

# this dictionary stores the primary copy of the ConfigObj's for each
# root package
_cfgobjs = {}
_cfgobjs: Final[dict[str, configobj.ConfigObj]] = {}


def get_config_filename(packageormod=None, rootname=None):
Expand Down
46 changes: 35 additions & 11 deletions astropy/config/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
data/cache files used by Astropy should be placed.
"""

from __future__ import annotations

import os
import shutil
import sys
from functools import wraps
from inspect import cleandoc
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType
from typing import Literal, ParamSpec

P = ParamSpec("P")

__all__ = [
"get_cache_dir",
Expand All @@ -20,7 +30,9 @@
]


def _get_dir_path(rootname: str, cls: type, fallback: str) -> Path:
def _get_dir_path(
rootname: str, cls: type[_SetTempPath], fallback: Literal["cache", "config"]
) -> Path:
# If using set_temp_x, that overrides all
if (xch := cls._temp_path) is not None:
path = xch / rootname
Expand Down Expand Up @@ -131,36 +143,43 @@ def get_cache_dir(rootname: str = "astropy") -> str:


class _SetTempPath:
_temp_path = None
_default_path_getter = None
_temp_path: Path | None = None
_default_path_getter: Callable[[str], str]

def __init__(self, path=None, delete=False):
def __init__(
self, path: os.PathLike[str] | str | None = None, delete: bool = False
) -> None:
if path is not None:
path = Path(path).resolve()

self._path = path
self._delete = delete
self._prev_path = self.__class__._temp_path

def __enter__(self):
def __enter__(self) -> str:
self.__class__._temp_path = self._path
try:
return self._default_path_getter("astropy")
except Exception:
self.__class__._temp_path = self._prev_path
raise

def __exit__(self, *args):
def __exit__(
self,
type: type[BaseException] | None,
value: BaseException | None,
tb: TracebackType | None,
) -> None:
self.__class__._temp_path = self._prev_path

if self._delete and self._path is not None:
shutil.rmtree(self._path)

def __call__(self, func):
def __call__(self, func: Callable[P, object]) -> Callable[P, None]:
"""Implements use as a decorator."""

@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
with self:
func(*args, **kwargs)

Expand Down Expand Up @@ -195,7 +214,7 @@ class set_temp_config(_SetTempPath):

_default_path_getter = staticmethod(get_config_dir)

def __enter__(self):
def __enter__(self) -> str:
# Special case for the config case, where we need to reset all the
# cached config objects. We do keep the cache, since some of it
# may have been set programmatically rather than be stored in the
Expand All @@ -206,13 +225,18 @@ def __enter__(self):
_cfgobjs.clear()
return super().__enter__()

def __exit__(self, *args):
def __exit__(
self,
type: type[BaseException] | None,
value: BaseException | None,
tb: TracebackType | None,
) -> None:
from .configuration import _cfgobjs

_cfgobjs.clear()
_cfgobjs.update(self._cfgobjs_copy)
del self._cfgobjs_copy
super().__exit__(*args)
super().__exit__(type, value, tb)


class set_temp_cache(_SetTempPath):
Expand Down

0 comments on commit 45ff837

Please sign in to comment.