diff --git a/data/support_metadata.yaml b/data/support_metadata.yaml index 115156af..231c795f 100644 --- a/data/support_metadata.yaml +++ b/data/support_metadata.yaml @@ -9,7 +9,6 @@ columns: x2: categorical: true dtype: int64 - missingness: drop x3: categorical: true x4: diff --git a/poetry.lock b/poetry.lock index 067274f4..2a941439 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1122,13 +1122,13 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.38" +version = "3.1.40" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.38-py3-none-any.whl", hash = "sha256:9e98b672ffcb081c2c8d5aa630d4251544fb040fb158863054242f24a2a2ba30"}, - {file = "GitPython-3.1.38.tar.gz", hash = "sha256:4d683e8957c8998b58ddb937e3e6cd167215a180e1ffd4da769ab81c620a89fe"}, + {file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"}, + {file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"}, ] [package.dependencies] @@ -2540,21 +2540,21 @@ files = [ [[package]] name = "networkx" -version = "3.1" +version = "3.2" description = "Python package for creating and manipulating graphs and networks" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, - {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, + {file = "networkx-3.2-py3-none-any.whl", hash = "sha256:8b25f564bd28f94ac821c58b04ae1a3109e73b001a7d476e4bb0d00d63706bf8"}, + {file = "networkx-3.2.tar.gz", hash = "sha256:bda29edf392d9bfa5602034c767d28549214ec45f620081f0b74dc036a1fbbc1"}, ] [package.extras] -default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] -developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] -doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] -test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "nodeenv" diff --git a/src/nhssynth/modules/dataloader/metatransformer.py b/src/nhssynth/modules/dataloader/metatransformer.py index 355e1d2b..9377ff3c 100644 --- a/src/nhssynth/modules/dataloader/metatransformer.py +++ b/src/nhssynth/modules/dataloader/metatransformer.py @@ -1,16 +1,13 @@ import pathlib import sys -from typing import Any, Callable, Optional, Self, Union +from typing import Any, Optional, Self, Union import numpy as np import pandas as pd from tqdm import tqdm from nhssynth.modules.dataloader.metadata import MetaData -from nhssynth.modules.dataloader.missingness import ( - MISSINGNESS_STRATEGIES, - ImputeMissingnessStrategy, -) +from nhssynth.modules.dataloader.missingness import MISSINGNESS_STRATEGIES class MetaTransformer: @@ -59,26 +56,9 @@ def __init__( if missingness_strategy == "impute": assert ( impute_value is not None - ), "`impute_value` must be specified when using the imputation missingness strategy" - self._missingness_strategy = self._impute_missingness_strategy_generator(impute_value) - else: - self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy] - - def _impute_missingness_strategy_generator(self, impute_value: Any) -> Callable[[], ImputeMissingnessStrategy]: - """ - Create a function to return a new instance of the impute missingness strategy with the given impute value. - - Args: - impute_value: The value to use when imputing missing values in the data. - - Returns: - A function that returns a new instance of the impute missingness strategy with the given impute value. - """ - - def _impute_missingness_strategy() -> ImputeMissingnessStrategy: - return ImputeMissingnessStrategy(impute_value) - - return _impute_missingness_strategy + ), "`impute_value` of the `MetaTransformer` must be specified (via the --impute flag) when using the imputation missingness strategy" + self._impute_value = impute_value + self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy] @classmethod def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self: @@ -193,7 +173,11 @@ def apply_missingness_strategy(self) -> pd.DataFrame: working_data = self.typed_dataset.copy() for column_metadata in self._metadata: if not column_metadata.missingness_strategy: - column_metadata.missingness_strategy = self._missingness_strategy() + column_metadata.missingness_strategy = ( + self._missingness_strategy(self._impute_value) + if hasattr(self, "_impute_value") + else self._missingness_strategy() + ) if not working_data[column_metadata.name].isnull().any(): continue working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata) diff --git a/src/nhssynth/modules/dataloader/missingness.py b/src/nhssynth/modules/dataloader/missingness.py index 34e65480..be507f24 100644 --- a/src/nhssynth/modules/dataloader/missingness.py +++ b/src/nhssynth/modules/dataloader/missingness.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +import warnings from abc import ABC, abstractmethod from typing import Any, Final @@ -73,7 +74,10 @@ def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.Data Returns: The dataset with missing values in the appropriate column replaced with imputed ones. """ - if self.impute == "mean": + if (self.impute == "mean" or self.impute == "median") and column_metadata.categorical: + warnings.warn("Cannot impute mean or median for categorical data, using mode instead.") + self.imputation_value = data[column_metadata.name].mode()[0] + elif self.impute == "mean": self.imputation_value = data[column_metadata.name].mean() elif self.impute == "median": self.imputation_value = data[column_metadata.name].median() @@ -81,7 +85,11 @@ def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.Data self.imputation_value = data[column_metadata.name].mode()[0] else: self.imputation_value = self.impute - data[column_metadata.name].fillna(self.imputation_value, inplace=True) + self.imputation_value = column_metadata.dtype.type(self.imputation_value) + try: + data[column_metadata.name].fillna(self.imputation_value, inplace=True) + except AssertionError: + raise ValueError(f"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.") return data @@ -112,7 +120,7 @@ def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.Data MISSINGNESS_STRATEGIES: Final = { - "none": NullMissingnessStrategy, + # "none": NullMissingnessStrategy, "impute": ImputeMissingnessStrategy, "augment": AugmentMissingnessStrategy, "drop": DropMissingnessStrategy,