Skip to content

Commit

Permalink
Fix bug in imputation missingness strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonWilde committed Oct 19, 2023
1 parent 82518c4 commit 03b1974
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 42 deletions.
1 change: 0 additions & 1 deletion data/support_metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ columns:
x2:
categorical: true
dtype: int64
missingness: drop
x3:
categorical: true
x4:
Expand Down
24 changes: 12 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 10 additions & 26 deletions src/nhssynth/modules/dataloader/metatransformer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import pathlib
import sys
from typing import Any, Callable, Optional, Self, Union
from typing import Any, Optional, Self, Union

import numpy as np
import pandas as pd
from tqdm import tqdm

from nhssynth.modules.dataloader.metadata import MetaData
from nhssynth.modules.dataloader.missingness import (
MISSINGNESS_STRATEGIES,
ImputeMissingnessStrategy,
)
from nhssynth.modules.dataloader.missingness import MISSINGNESS_STRATEGIES


class MetaTransformer:
Expand Down Expand Up @@ -59,26 +56,9 @@ def __init__(
if missingness_strategy == "impute":
assert (
impute_value is not None
), "`impute_value` must be specified when using the imputation missingness strategy"
self._missingness_strategy = self._impute_missingness_strategy_generator(impute_value)
else:
self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy]

def _impute_missingness_strategy_generator(self, impute_value: Any) -> Callable[[], ImputeMissingnessStrategy]:
"""
Create a function to return a new instance of the impute missingness strategy with the given impute value.
Args:
impute_value: The value to use when imputing missing values in the data.
Returns:
A function that returns a new instance of the impute missingness strategy with the given impute value.
"""

def _impute_missingness_strategy() -> ImputeMissingnessStrategy:
return ImputeMissingnessStrategy(impute_value)

return _impute_missingness_strategy
), "`impute_value` of the `MetaTransformer` must be specified (via the --impute flag) when using the imputation missingness strategy"
self._impute_value = impute_value
self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy]

@classmethod
def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:
Expand Down Expand Up @@ -193,7 +173,11 @@ def apply_missingness_strategy(self) -> pd.DataFrame:
working_data = self.typed_dataset.copy()
for column_metadata in self._metadata:
if not column_metadata.missingness_strategy:
column_metadata.missingness_strategy = self._missingness_strategy()
column_metadata.missingness_strategy = (
self._missingness_strategy(self._impute_value)
if hasattr(self, "_impute_value")
else self._missingness_strategy()
)
if not working_data[column_metadata.name].isnull().any():
continue
working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)
Expand Down
14 changes: 11 additions & 3 deletions src/nhssynth/modules/dataloader/missingness.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
import warnings
from abc import ABC, abstractmethod
from typing import Any, Final

Expand Down Expand Up @@ -73,15 +74,22 @@ def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.Data
Returns:
The dataset with missing values in the appropriate column replaced with imputed ones.
"""
if self.impute == "mean":
if (self.impute == "mean" or self.impute == "median") and column_metadata.categorical:
warnings.warn("Cannot impute mean or median for categorical data, using mode instead.")
self.imputation_value = data[column_metadata.name].mode()[0]
elif self.impute == "mean":
self.imputation_value = data[column_metadata.name].mean()
elif self.impute == "median":
self.imputation_value = data[column_metadata.name].median()
elif self.impute == "mode":
self.imputation_value = data[column_metadata.name].mode()[0]
else:
self.imputation_value = self.impute
data[column_metadata.name].fillna(self.imputation_value, inplace=True)
self.imputation_value = column_metadata.dtype.type(self.imputation_value)
try:
data[column_metadata.name].fillna(self.imputation_value, inplace=True)
except AssertionError:
raise ValueError(f"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.")
return data


Expand Down Expand Up @@ -112,7 +120,7 @@ def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.Data


MISSINGNESS_STRATEGIES: Final = {
"none": NullMissingnessStrategy,
# "none": NullMissingnessStrategy,
"impute": ImputeMissingnessStrategy,
"augment": AugmentMissingnessStrategy,
"drop": DropMissingnessStrategy,
Expand Down

0 comments on commit 03b1974

Please sign in to comment.