black linting
Jonny Pearson committed Dec 20, 2024
1 parent edf2a23 commit 24c7595
Showing 8 changed files with 73 additions and 55 deletions.
4 changes: 1 addition & 3 deletions src/nhssynth/cli/common_arguments.py
@@ -68,9 +68,7 @@ def get_seed_parser(overrides=False) -> argparse.ArgumentParser:
return parser


-COMMON_TITLE: Final = (
-    "starting any of the following args with `_` defaults to a suffix on DATASET (e.g. `_metadata` -> `<DATASET>_metadata`);\nall filenames are relative to `experiments/<EXPERIMENT_NAME>/` unless otherwise stated"
-)
+COMMON_TITLE: Final = "starting any of the following args with `_` defaults to a suffix on DATASET (e.g. `_metadata` -> `<DATASET>_metadata`);\nall filenames are relative to `experiments/<EXPERIMENT_NAME>/` unless otherwise stated"


def suffix_parser_generator(name: str, help: str, required: bool = False) -> argparse.ArgumentParser:
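As context for the COMMON_TITLE change above, the `_`-prefix convention it documents can be sketched as below. `resolve_filename` is a hypothetical helper invented for illustration, not a function in nhssynth.

# Hypothetical sketch of the suffix convention described in COMMON_TITLE.
def resolve_filename(arg: str, dataset: str) -> str:
    """Expand a leading underscore into a suffix on the dataset name."""
    if arg.startswith("_"):
        return f"{dataset}{arg}"  # "_metadata" -> "<DATASET>_metadata"
    return arg

assert resolve_filename("_metadata", "mydata") == "mydata_metadata"
assert resolve_filename("other.yaml", "mydata") == "other.yaml"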
23 changes: 10 additions & 13 deletions src/nhssynth/modules/dataloader/constraints.py
@@ -10,12 +10,13 @@ class ConstraintGraph:
POSITIVITY_TO_OPERATOR: Final = {"positive": ">", "nonnegative": ">=", "negative": "<", "nonpositive": "<="}
BRACKET_TO_OPERATOR: Final = {"[": ">=", "]": "<=", "(": ">", ")": "<"}
OPERATOR_TO_PANDAS: Final = {"<": pd.Series.lt, "<=": pd.Series.le, ">": pd.Series.gt, ">=": pd.Series.ge}

class Constraint:
VALID_OPERATORS: Final = [">", ">=", "<", "<=", "in"]
POSITIVITY_TO_OPERATOR: Final = {"positive": ">", "nonnegative": ">=", "negative": "<", "nonpositive": "<="}
BRACKET_TO_OPERATOR: Final = {"[": ">=", "]": "<=", "(": ">", ")": "<"}
OPERATOR_TO_PANDAS: Final = {"<": pd.Series.lt, "<=": pd.Series.le, ">": pd.Series.gt, ">=": pd.Series.ge}

def __init__(
self,
base: str,
@@ -41,7 +42,7 @@ def __eq__(self, other) -> bool:
and self.reference == other.reference
and self.reference_is_column == other.reference_is_column
)

def transform(self, df):
# Ensure that the base column exists in the DataFrame
if self.base not in df.columns:
Expand All @@ -50,7 +51,7 @@ def transform(self, df):
# Handle float-based constraints (e.g., columnA > 10)
if not self.reference_is_column:
reference = float(self.reference)
adherence = self.OPERATOR_TO_PANDAS[self.operator](df[self.base], reference)
else:
# Handle column-to-column constraints (e.g., columnB <= columnC)
reference = df[self.reference]
@@ -61,16 +62,12 @@ def transform(self, df):

# Optionally calculate and store the difference for rows that don't meet the constraint
# This is useful for identifying the "degree" to which the constraint is violated
# diff = np.abs(df[self.base] - self.reference)
# diff[~adherence] = np.nan  # Set diff to NaN where adherence is False
# df[self.base + "_diff"] = diff

return df


class ComboConstraint:
def __init__(self, columns: list[str]):
self.columns = columns
@@ -265,10 +262,10 @@ def _traverse_longest_path(
ref_is_col, operator = True, ">"
if subgraph.edges[item1, item2]["color"] == "green":
operator += "="
-if subgraph.nodes[item1]["color"] == "red": # Note: this breaks if two none col constraints are the same!
+if subgraph.nodes[item1]["color"] == "red":  # Note: this breaks if two none col constraints are the same!
item1, item2 = item2, item1
ref_is_col, operator = False, operator.replace(">", "<")
-if subgraph.nodes[item2]["color"] == "red": # Note: this breaks if two none col constraints are the same!
+if subgraph.nodes[item2]["color"] == "red":  # Note: this breaks if two none col constraints are the same!
ref_is_col = False
constraint = self.Constraint(item1, operator, item2, reference_is_column=ref_is_col)
if constraint not in constraints:
@@ -317,7 +314,7 @@ def _output_graphs_html(self, name: str) -> None:
html = net.generate_html(notebook=False)
with open(str(name).replace(".html", "_minimal.html"), "w") as f:
f.write(html)

def __iter__(self):
"""
Make the ConstraintGraph iterable over the minimal_constraints.
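The adherence check in Constraint.transform above can be reproduced standalone. This sketch uses the same OPERATOR_TO_PANDAS mapping shown in the diff; the column and constraint values are illustrative.

import pandas as pd

OPERATOR_TO_PANDAS = {"<": pd.Series.lt, "<=": pd.Series.le, ">": pd.Series.gt, ">=": pd.Series.ge}

df = pd.DataFrame({"columnA": [5.0, 12.0, 20.0]})
operator, reference = ">", 10.0  # the constraint "columnA > 10"

# OPERATOR_TO_PANDAS[">"](series, 10.0) is the unbound form of series.gt(10.0)
adherence = OPERATOR_TO_PANDAS[operator](df["columnA"], reference)
df["columnA_adherence"] = adherence.astype(int)
print(df)  # the 5.0 row fails the constraint and gets adherence 0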
8 changes: 6 additions & 2 deletions src/nhssynth/modules/dataloader/metadata.py
@@ -80,7 +80,7 @@ def _validate_rounding_scheme(self, data: pd.Series, dtype: np.dtype, dtype_dict
roundable_data = data[data.notna()]
for i in range(np.finfo(dtype).precision):
if (roundable_data.round(i) == roundable_data).all():
-return 10**-i
+return 10 ** -i
return None

def _validate_categorical(self, data: pd.Series, categorical: Optional[bool] = None) -> bool:
@@ -321,7 +314,7 @@ def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:
"sdtype": (
"boolean"
if cmd.boolean
else "categorical" if cmd.categorical else "datetime" if cmd.dtype.kind == "M" else "numerical"
else "categorical"
if cmd.categorical
else "datetime"
if cmd.dtype.kind == "M"
else "numerical"
),
}
for cn, cmd in self._metadata.items()
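The rounding-scheme loop touched by the `10 ** -i` change above searches for the coarsest decimal precision that leaves the data unchanged and reports it as a rounding unit. A standalone sketch, assuming float64 data:

import numpy as np
import pandas as pd

def infer_rounding_unit(data: pd.Series, dtype=np.float64):
    """Return 10**-i for the smallest i at which rounding is lossless."""
    roundable_data = data[data.notna()]
    for i in range(np.finfo(dtype).precision):
        if (roundable_data.round(i) == roundable_data).all():
            return 10 ** -i
    return None

print(infer_rounding_unit(pd.Series([1.25, 3.5, 7.75])))  # 0.01
print(infer_rounding_unit(pd.Series([1.0, 3.0, np.nan])))  # 1 (whole numbers)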
12 changes: 7 additions & 5 deletions src/nhssynth/modules/dataloader/metatransformer.py
@@ -133,7 +133,9 @@ def _apply_dtype(
dtype = column_metadata.dtype
try:
if dtype.kind == "M":
-working_column = pd.to_datetime(working_column, format=column_metadata.datetime_config.get("format"), errors='coerce')
+working_column = pd.to_datetime(
+    working_column, format=column_metadata.datetime_config.get("format"), errors="coerce"
+)
if column_metadata.datetime_config.get("floor"):
working_column = working_column.dt.floor(column_metadata.datetime_config.get("floor"))
column_metadata.datetime_config["format"] = column_metadata._infer_datetime_format(working_column)
@@ -206,14 +208,14 @@ def _get_missingness_carrier(self, column_metadata: MetaData.ColumnMetaData) ->
return self.post_missingness_strategy_dataset[missingness_carrier]
else:
return missingness_carrier

def _get_adherence_constraint(self, df) -> Union[pd.Series, Any]:

-adherence_columns = [col for col in df.columns if col.endswith('_adherence')]
-constraint_adherence = df[adherence_columns].prod(axis=1).astype(int)
+adherence_columns = [col for col in df.columns if col.endswith("_adherence")]
+constraint_adherence = df[adherence_columns].prod(axis=1).astype(int)

return constraint_adherence

def transform(self) -> pd.DataFrame:
"""
Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.
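The `_get_adherence_constraint` method reformatted above reduces the per-constraint `*_adherence` flag columns to a single row-wise indicator: the product over {0, 1} flags is 1 only when every constraint holds. A small sketch with made-up columns:

import pandas as pd

df = pd.DataFrame(
    {
        "age_adherence": [1, 1, 0],
        "height_adherence": [1, 0, 0],
        "weight": [70, 80, 90],
    }
)
adherence_columns = [col for col in df.columns if col.endswith("_adherence")]
constraint_adherence = df[adherence_columns].prod(axis=1).astype(int)
print(constraint_adherence.tolist())  # [1, 0, 0]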
12 changes: 10 additions & 2 deletions src/nhssynth/modules/dataloader/transformers/base.py
@@ -11,7 +11,9 @@ def __init__(self) -> None:
super().__init__()

@abstractmethod
-def apply(self, data: pd.DataFrame, missingness_column: Optional[pd.Series], constraint_adherence: Optional[pd.Series]) -> None:
+def apply(
+    self, data: pd.DataFrame, missingness_column: Optional[pd.Series], constraint_adherence: Optional[pd.Series]
+) -> None:
"""Apply the transformer to the data."""
pass

@@ -33,7 +35,13 @@ def __init__(self, wrapped_transformer: ColumnTransformer) -> None:
super().__init__()
self._wrapped_transformer: ColumnTransformer = wrapped_transformer

-def apply(self, data: pd.Series, missingness_column: Optional[pd.Series], constraint_adherence: Optional[pd.Series], **kwargs) -> pd.DataFrame:
+def apply(
+    self,
+    data: pd.Series,
+    missingness_column: Optional[pd.Series],
+    constraint_adherence: Optional[pd.Series],
+    **kwargs
+) -> pd.DataFrame:
"""Method for applying the wrapped transformer to the data."""
return self._wrapped_transformer.apply(data, missingness_column, constraint_adherence, **kwargs)

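The two signatures reformatted above follow a simple delegation pattern: TransformerWrapper forwards apply() to the transformer it wraps. A stripped-down sketch, with simplified stand-ins for the real classes:

from abc import ABC, abstractmethod

class ColumnTransformer(ABC):
    @abstractmethod
    def apply(self, data, missingness_column=None, constraint_adherence=None, **kwargs):
        ...

class TransformerWrapper:
    def __init__(self, wrapped_transformer: ColumnTransformer) -> None:
        self._wrapped_transformer = wrapped_transformer

    def apply(self, data, missingness_column=None, constraint_adherence=None, **kwargs):
        # Forward everything to the wrapped transformer unchanged.
        return self._wrapped_transformer.apply(data, missingness_column, constraint_adherence, **kwargs)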
32 changes: 17 additions & 15 deletions src/nhssynth/modules/dataloader/transformers/categorical.py
@@ -31,32 +31,34 @@ def __init__(self, drop: Optional[Union[list, str]] = None) -> None:
self._transformer: OneHotEncoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False, drop=self._drop)
self.missing_value: Any = None

-def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], missing_value: Optional[Any] = None) -> pd.DataFrame:
+def apply(
+    self, data: pd.Series, constraint_adherence: Optional[pd.Series], missing_value: Optional[Any] = None
+) -> pd.DataFrame:
"""
Applies a transformation to the input data using scikit-learn's `OneHotEncoder`'s `fit_transform` method.
This method transforms the input data (`data`) into one-hot encoded format. If a `missing_value` is provided, missing values are replaced
with the specified value before the transformation is applied. The transformation is further filtered based on the provided
`constraint_adherence` Series, which determines which rows are included in the transformation process (only rows where the value in
`constraint_adherence` is `1` are retained).
The resulting transformed data includes the one-hot encoded columns for the original data, with the constraint adherence values
appended as an additional column. If a column labeled `0` is created (which may happen in certain transformations), it is dropped.
Args:
data (pd.Series): The input column of data to be transformed. The data is expected to be a single column to which the transformation will be applied.
constraint_adherence (Optional[pd.Series]): A Series indicating whether each row should be included in the transformation.
Rows corresponding to `1` will be included, and rows corresponding to `0` will be excluded.
missing_value (Optional[Any]): The value used to replace missing values (`NaN`) in the `data` before applying the transformation.
This is primarily used for the `AugmentMissingnessStrategy` to ensure that missing values are treated as a specific category.
Returns:
pd.DataFrame: A DataFrame containing the transformed data. The DataFrame consists of:
- One-hot encoded columns based on the original data, with each column corresponding to a unique category.
- The `constraint_adherence` column, indicating whether the row satisfies the user-defined constraints (values of `1` or `0`).
Notes:
- If `missing_value` is provided, missing values in the `data` column are replaced with the specified value before the transformation.
- Only rows where `constraint_adherence == 1` will be included in the transformed data.
@@ -69,7 +71,7 @@ def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], miss
data = data.fillna(missing_value)
self.missing_value = missing_value
semi_index = data.index

data = data[constraint_adherence == 1]
transformed_data = pd.DataFrame(
self._transformer.fit_transform(data.values.reshape(-1, 1)),
Expand All @@ -79,7 +81,7 @@ def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], miss

if 0 in transformed_data.columns:
transformed_data = transformed_data.drop(columns=[0])

self.new_column_names = transformed_data.columns
return transformed_data

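A runnable approximation of the flow the docstring above describes: fill missing values with a sentinel category, one-hot encode only the constraint-adhering rows, then restore the original index and re-attach the adherence flags. Column naming here is simplified relative to the real transformer.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder

data = pd.Series(["a", "b", None, "a"], name="colour")
constraint_adherence = pd.Series([1, 1, 1, 0], name="colour_adherence")

data = data.fillna("colour_missing")  # treat NaN as its own category
semi_index = data.index
data = data[constraint_adherence == 1]  # keep only adhering rows

transformer = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
transformed_data = pd.DataFrame(
    transformer.fit_transform(data.values.reshape(-1, 1)),
    index=data.index,
    columns=transformer.get_feature_names_out([data.name]),
)
# Non-adhering rows come back as all zeros; the adherence column is appended.
transformed_data = pd.concat(
    [transformed_data.reindex(semi_index).fillna(0.0), constraint_adherence], axis=1
)
print(transformed_data)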
29 changes: 15 additions & 14 deletions src/nhssynth/modules/dataloader/transformers/continuous.py
@@ -58,34 +58,36 @@ def __init__(
self.remove_unused_components = remove_unused_components
self.clip_output = clip_output

-def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], missingness_column: Optional[pd.Series] = None) -> pd.DataFrame:
+def apply(
+    self, data: pd.Series, constraint_adherence: Optional[pd.Series], missingness_column: Optional[pd.Series] = None
+) -> pd.DataFrame:
"""
Apply the transformation to a given data column using the `BayesianGaussianMixture` model from scikit-learn.
This method transforms the input data (`data`) by fitting a `BayesianGaussianMixture` model to the data, and normalizes the values based on the
learned parameters. Additionally, it handles missing data by utilizing the provided `missingness_column` and `constraint_adherence` to determine
which rows should be included in the transformation. The resulting transformed data consists of the normalized values, along with component
probabilities, and a final adherence column indicating whether the data satisfies the constraints.
If the `missingness_column` is provided, missing values are handled by assigning them to a new pseudo-cluster with a mean of 0, ensuring that
missing data does not affect the transformation process. Missing values are filled with zeros, and the column names are updated accordingly.
Args:
data (pd.Series): The input column of data to be transformed. This column is used to fit the `BayesianGaussianMixture` model.
constraint_adherence (Optional[pd.Series]): A series indicating whether each row satisfies the user-defined constraints. Only rows where
the value in `constraint_adherence` is 1 are included in the transformation process.
missingness_column (Optional[pd.Series]): A series indicating missing values. If provided, missing values will be assigned to a pseudo-cluster
with mean 0. The missing values are handled separately to ensure that they don't interfere with the transformation.
Returns:
pd.DataFrame: A DataFrame containing the transformed data with the following columns:
- `<original_column_name>_normalised`: The normalized version of the input data.
- `<original_column_name>_c1`, ..., `<original_column_name>_cn`: Columns representing the component probabilities, where `n` is the
number of components in the `BayesianGaussianMixture` model.
- The `constraint_adherence` column, which contains 1s and 0s indicating whether each row adheres to the user-defined constraints.
Notes:
- The method uses the `fit` and `predict_proba` methods of `BayesianGaussianMixture` to fit the model and calculate component probabilities.
- If the `missingness_column` is provided, rows with missing values will be handled separately by assigning them to a pseudo-cluster with mean 0.
@@ -120,8 +122,7 @@ def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], miss
normalised = np.clip(normalised, -1.0, 1.0)
print(normalised)
components = np.eye(self._n_components, dtype=int)[components]

transformed_data = pd.DataFrame(
np.hstack([normalised.reshape(-1, 1), components]),
index=index,
Expand All @@ -130,14 +131,14 @@ def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], miss
)
print(transformed_data)
# EXPERIMENTAL feature, removing components from the column matrix that have no data assigned to them
-'''if self.remove_unused_components:
+"""if self.remove_unused_components:
nunique = transformed_data.iloc[:, 1:].nunique(dropna=False)
unused_components = nunique[nunique == 1].index
unused_component_idx = [transformed_data.columns.get_loc(col_name) - 1 for col_name in unused_components]
self.means = np.delete(self.means, unused_component_idx)
self.stds = np.delete(self.stds, unused_component_idx)
transformed_data.drop(unused_components, axis=1, inplace=True)
-'''
+"""

transformed_data = pd.concat([transformed_data.reindex(semi_index).fillna(0.0), constraint_adherence], axis=1)

Expand All @@ -146,7 +147,7 @@ def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], miss

if 0 in transformed_data.columns:
transformed_data = transformed_data.drop(columns=[0])

self.new_column_names = transformed_data.columns
return transformed_data.astype(
{col_name: int for col_name in transformed_data.columns if re.search(r"_c\d+", col_name)}
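A sketch of the mode-specific normalisation the docstring above describes, using BayesianGaussianMixture. The diff only shows the clipping to [-1, 1] and the one-hot component matrix; the (x - mean) / (4 * std) scaling and the argmax component assignment are assumptions here, borrowed from the usual CTGAN-style convention.

import numpy as np
import pandas as pd
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.default_rng(0)
data = pd.Series(np.r_[rng.normal(0, 1, 200), rng.normal(10, 2, 200)], name="x")

n_components = 3
bgm = BayesianGaussianMixture(n_components=n_components, random_state=0)
X = data.values.reshape(-1, 1)
bgm.fit(X)

means = bgm.means_.reshape(-1)
stds = np.sqrt(bgm.covariances_).reshape(-1)

# Assign each value to its most probable component, normalise against that
# component's parameters, clip, and one-hot encode the assignment.
components = bgm.predict_proba(X).argmax(axis=1)
normalised = np.clip((X.reshape(-1) - means[components]) / (4 * stds[components]), -1.0, 1.0)
one_hot = np.eye(n_components, dtype=int)[components]

transformed_data = pd.DataFrame(
    np.hstack([normalised.reshape(-1, 1), one_hot]),
    columns=[f"{data.name}_normalised"] + [f"{data.name}_c{i + 1}" for i in range(n_components)],
)
print(transformed_data.head())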
8 changes: 7 additions & 1 deletion src/nhssynth/modules/dataloader/transformers/datetime.py
@@ -26,7 +26,13 @@ class DatetimeTransformer(TransformerWrapper):
def __init__(self, transformer: ColumnTransformer) -> None:
super().__init__(transformer)

-def apply(self, data: pd.Series, constraint_adherence: Optional[pd.Series], missingness_column: Optional[pd.Series] = None, **kwargs) -> pd.DataFrame:
+def apply(
+    self,
+    data: pd.Series,
+    constraint_adherence: Optional[pd.Series],
+    missingness_column: Optional[pd.Series] = None,
+    **kwargs
+) -> pd.DataFrame:
"""
Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch.
The float value of `pd.NaT` under the operation above is then replaced with `np.nan` to ensure missing values are represented correctly.
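A standalone sketch of the conversion this docstring describes: floor the datetimes to nanoseconds, cast to float nanoseconds since the epoch, and represent pd.NaT as np.nan. Illustrative only, not the module's actual code.

import numpy as np
import pandas as pd

data = pd.Series(pd.to_datetime(["2024-12-20 10:30:00", None, "1970-01-01"]))

floored = data.dt.floor("ns")
nanoseconds = pd.Series(np.nan, index=floored.index)
mask = floored.notna()
# Casting NaT-free datetimes to int64 gives nanoseconds since the epoch.
nanoseconds[mask] = floored[mask].astype("int64").astype(float)
print(nanoseconds)  # approx. [1.73469e+18, nan, 0.0]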
