Skip to content

Commit

Permalink
Merge pull request #44 from apollo13/pydantic2-fixes
Browse files Browse the repository at this point in the history
Use contextvars instead of patching pydantic internals.
  • Loading branch information
dchukhin authored Oct 8, 2024
2 parents 850e966 + 1e1f7f8 commit 6d9a49f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 58 deletions.
101 changes: 49 additions & 52 deletions goodconf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Transparently load variables from environment or JSON/YAML file.
"""

# Note: the following line is included to ensure Python3.9 compatibility.
from __future__ import annotations

Expand All @@ -9,35 +10,28 @@
import logging
import os
import sys
from functools import partial
from io import StringIO
from types import GenericAlias
from typing import (
Any,
List,
Optional,
Tuple,
Type,
cast,
get_origin,
get_args,
Union,
)
from typing import TYPE_CHECKING, cast, get_args

from pydantic import PrivateAttr
from pydantic._internal._config import config_keys
from pydantic.fields import ( # noqa
Field,
FieldInfo,
ModelPrivateAttr,
PydanticUndefined,
)
from pydantic.fields import Field, PydanticUndefined
from pydantic.main import _object_setattr
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)

if TYPE_CHECKING:
from typing import Any

from pydantic.fields import FieldInfo


__all__ = ["GoodConf", "GoodConfConfigDict", "Field"]

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -156,16 +150,12 @@ def get_field_value(
def __call__(self) -> dict[str, Any]:
settings = cast(GoodConf, self.settings_cls)
selected_config_file = None
# already loaded from a file
if not isinstance(settings._config_file, ModelPrivateAttr):
return {}
elif (
settings.model_config.get("file_env_var")
and settings.model_config["file_env_var"] in os.environ
if cfg_file := self.current_state.get("_config_file"):
selected_config_file = cfg_file
elif (file_env_var := settings.model_config.get("file_env_var")) and (
cfg_file := os.environ.get(file_env_var)
):
selected_config_file = _find_file(
os.environ[settings.model_config["file_env_var"]]
)
selected_config_file = _find_file(cfg_file)
else:
for filename in settings.model_config.get("default_files") or []:
selected_config_file = _find_file(filename, require=False)
Expand All @@ -174,38 +164,31 @@ def __call__(self) -> dict[str, Any]:
if selected_config_file:
values = _load_config(selected_config_file)
log.info("Loading config from %s", selected_config_file)
settings._config_file = selected_config_file
else:
values = {}
log.info("No config file specified. Loading with environment variables.")
settings._config_file = None
return values

def __repr__(self) -> str:
return "FileConfigSettingsSource()"


class GoodConf(BaseSettings):
_config_file: str = PrivateAttr(None)

def __init__(self, load: bool = False, config_file: str | None = None, **kwargs):
def __init__(
self, load: bool = False, config_file: str | None = None, **kwargs
) -> None:
"""
:param load: load config file on instantiation [default: False].
A docstring defined on the class should be a plain-text description
used as a header when generating a configuration file.
"""
# At this point __pydantic_private__ is None, so setting self.config_file
# raises an error. To avoid this error, explicitly set
# __pydantic_private__ to {} prior to setting self._config_file.
_object_setattr(self, "__pydantic_private__", {})
self._config_file = config_file

# Emulate Pydantic behavior, load immediately
if kwargs:
return super().__init__(**kwargs)
elif load:
return self.load()
if kwargs or load: # Emulate Pydantic behavior, load immediately
self._load(_init_config_file=config_file, **kwargs)
elif config_file:
_object_setattr(
self, "_load", partial(self._load, _init_config_file=config_file)
)

@classmethod
def settings_customise_sources(
Expand All @@ -227,16 +210,30 @@ def settings_customise_sources(

model_config = GoodConfConfigDict()

def _settings_build_values(
self,
init_kwargs: dict[str, Any],
**kwargs,
) -> dict[str, Any]:
state = super()._settings_build_values(
init_kwargs,
**kwargs,
)
state.pop("_config_file", None)
return state

def _load(
self,
_config_file: str | None = None,
_init_config_file: str | None = None,
**kwargs,
):
if config_file := _config_file or _init_config_file:
kwargs["_config_file"] = config_file
super().__init__(**kwargs)

def load(self, filename: str | None = None) -> None:
"""Find config file and set values"""
if filename:
values = _load_config(filename)
log.info("Loading config from %s", filename)
else:
values = {}
super().__init__(**values)
if filename:
_object_setattr(self, "_config_file", filename)
self._load(_config_file=filename)

@classmethod
def get_initial(cls, **override) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"pydantic>=2.0",
"pydantic-settings>=2.0",
"pydantic>=2.7",
"pydantic-settings>=2.4",
]

[project.optional-dependencies]
Expand Down
19 changes: 15 additions & 4 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class G(GoodConf):
g = G()
g.load()
mocked_load_config.assert_called_once_with(str(path))
assert g._config_file == str(path)


def test_conflict(tmpdir):
Expand All @@ -39,14 +38,15 @@ class G(GoodConf):

def test_all_env_vars(mocker):
mocked_set_values = mocker.patch("goodconf.BaseSettings.__init__")
mocked_load_config = mocker.patch("goodconf._load_config")

class G(GoodConf):
pass

g = G()
g.load()
mocked_set_values.assert_called_once_with()
assert g._config_file is None
mocked_load_config.assert_not_called()


def test_provided_file(mocker, tmpdir):
Expand All @@ -60,7 +60,19 @@ class G(GoodConf):
g = G()
g.load(str(path))
mocked_load_config.assert_called_once_with(str(path))
assert g._config_file == str(path)


def test_provided_file_from_init(mocker, tmpdir):
mocked_load_config = mocker.patch("goodconf._load_config")
path = tmpdir.join("myapp.json")
path.write("")

class G(GoodConf):
pass

g = G(config_file=str(path))
g.load()
mocked_load_config.assert_called_once_with(str(path))


def test_default_files(mocker, tmpdir):
Expand All @@ -75,4 +87,3 @@ class G(GoodConf):
g = G()
g.load()
mocked_load_config.assert_called_once_with(str(path))
assert g._config_file == str(path)

0 comments on commit 6d9a49f

Please sign in to comment.