From aca281dabaf7a31892623d4e096d97d8c0cabc42 Mon Sep 17 00:00:00 2001 From: <> Date: Wed, 28 Feb 2024 13:03:20 +0000 Subject: [PATCH] Deployed a6187b8 with MkDocs version: 1.5.3 --- reference/cli/common_arguments/index.html | 170 +++---- reference/cli/config/index.html | 438 ++++++++-------- reference/cli/model_arguments/index.html | 334 ++++++------ reference/cli/module_arguments/index.html | 476 +++++++++--------- reference/cli/module_setup/index.html | 228 ++++----- reference/common/common/index.html | 26 +- reference/common/dicts/index.html | 160 +++--- reference/common/io/index.html | 218 ++++---- reference/common/strings/index.html | 90 ++-- .../modules/dataloader/metadata/index.html | 218 ++++---- search/search_index.json | 2 +- sitemap.xml.gz | Bin 127 -> 127 bytes 12 files changed, 1178 insertions(+), 1182 deletions(-) diff --git a/reference/cli/common_arguments/index.html b/reference/cli/common_arguments/index.html index 0fc5aef9..b4265abf 100644 --- a/reference/cli/common_arguments/index.html +++ b/reference/cli/common_arguments/index.html @@ -2561,8 +2561,7 @@

Source code in src/nhssynth/cli/common_arguments.py -
12
-13
+            
13
 14
 15
 16
@@ -2594,40 +2593,41 @@ 

42 43 44 -45

def get_core_parser(overrides=False) -> argparse.ArgumentParser:
-    """
-    Create the core common parser group applied to all modules (and the `pipeline` and `config` options).
-    Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.
-
-    Args:
-        overrides: whether the arguments declared within are required or not.
-
-    Returns:
-        The parser with the group containing the core arguments attached.
-    """
-    core = argparse.ArgumentParser(add_help=False)
-    core_grp = core.add_argument_group(title="options")
-    core_grp.add_argument(
-        "-d",
-        "--dataset",
-        required=(not overrides),
-        type=str,
-        help="the name of the dataset to experiment with, should be present in `<DATA_DIR>`",
-    )
-    core_grp.add_argument(
-        "-e",
-        "--experiment-name",
-        type=str,
-        default=TIME,
-        help="name the experiment run to affect logging, config, and default-behaviour i/o",
-    )
-    core_grp.add_argument(
-        "--save-config",
-        action="store_true",
-        help="save the config provided via the cli, this is a recommended option for reproducibility",
-    )
-    return core
+45
+46
def get_core_parser(overrides=False) -> argparse.ArgumentParser:
+    """
+    Create the core common parser group applied to all modules (and the `pipeline` and `config` options).
+    Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.
+
+    Args:
+        overrides: whether the arguments declared within are required or not.
+
+    Returns:
+        The parser with the group containing the core arguments attached.
+    """
+    core = argparse.ArgumentParser(add_help=False)
+    core_grp = core.add_argument_group(title="options")
+    core_grp.add_argument(
+        "-d",
+        "--dataset",
+        required=(not overrides),
+        type=str,
+        help="the name of the dataset to experiment with, should be present in `<DATA_DIR>`",
+    )
+    core_grp.add_argument(
+        "-e",
+        "--experiment-name",
+        type=str,
+        default=TIME,
+        help="name the experiment run to affect logging, config, and default-behaviour i/o",
+    )
+    core_grp.add_argument(
+        "--save-config",
+        action="store_true",
+        help="save the config provided via the cli, this is a recommended option for reproducibility",
+    )
+    return core
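A minimal usage sketch (not part of the documented source): the returned parser is intended to be composed via `parents=` so the shared group appears on every module sub-parser. The import path `nhssynth.cli.common_arguments` is assumed from the `src/` layout shown above, and the dataset name is illustrative.

import argparse

from nhssynth.cli.common_arguments import get_core_parser  # assumed import path

# Compose the core group into a throwaway parser and parse a sample command line.
parser = argparse.ArgumentParser(prog="nhssynth", parents=[get_core_parser()])
args = parser.parse_args(["--dataset", "support", "--save-config"])
print(args.dataset, args.experiment_name, args.save_config)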
 
@@ -2705,8 +2705,7 @@

Source code in src/nhssynth/cli/common_arguments.py -
48
-49
+            
49
 50
 51
 52
@@ -2724,26 +2723,27 @@ 

64 65 66 -67

def get_seed_parser(overrides=False) -> argparse.ArgumentParser:
-    """
-    Create the common parser group for the seed.
-    NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.
-
-    Args:
-        overrides: whether the arguments declared within are required or not.
-
-    Returns:
-        The parser with the group containing the seed argument attached.
-    """
-    parser = argparse.ArgumentParser(add_help=False)
-    parser_grp = parser.add_argument_group(title="options")
-    parser_grp.add_argument(
-        "-s",
-        "--seed",
-        type=int,
-        help="specify a seed for reproducibility, this is a recommended option for reproducibility",
-    )
-    return parser
+67
+68
def get_seed_parser(overrides=False) -> argparse.ArgumentParser:
+    """
+    Create the common parser group for the seed.
+    NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.
+
+    Args:
+        overrides: whether the arguments declared within are required or not.
+
+    Returns:
+        The parser with the group containing the seed argument attached.
+    """
+    parser = argparse.ArgumentParser(add_help=False)
+    parser_grp = parser.add_argument_group(title="options")
+    parser_grp.add_argument(
+        "-s",
+        "--seed",
+        type=int,
+        help="specify a seed for reproducibility, this is a recommended option for reproducibility",
+    )
+    return parser
 
@@ -2826,10 +2826,7 @@

Source code in src/nhssynth/cli/common_arguments.py -
73
-74
-75
-76
+            
76
 77
 78
 79
@@ -2848,29 +2845,32 @@ 

92 93 94 -95

def suffix_parser_generator(name: str, help: str, required: bool = False) -> argparse.ArgumentParser:
-    """Generator function for creating parsers following a common template.
-    These parsers are all suffixes to the --dataset / -d / DATASET argument, see `COMMON_TITLE`.
-
-    Args:
-        name: the name / label of the argument to add to the CLI options.
-        help: the help message when the CLI is run with --help / -h.
-        required: whether the argument must be provided or not.
-    """
-
-    def get_parser(overrides: bool = False) -> argparse.ArgumentParser:
-        parser = argparse.ArgumentParser(add_help=False)
-        parser_grp = parser.add_argument_group(title=COMMON_TITLE)
-        parser_grp.add_argument(
-            f"--{name.replace('_', '-')}",
-            required=required and not overrides,
-            type=str,
-            default=f"_{name}",
-            help=help,
-        )
-        return parser
-
-    return get_parser
+95
+96
+97
+98
def suffix_parser_generator(name: str, help: str, required: bool = False) -> argparse.ArgumentParser:
+    """Generator function for creating parsers following a common template.
+    These parsers are all suffixes to the --dataset / -d / DATASET argument, see `COMMON_TITLE`.
+
+    Args:
+        name: the name / label of the argument to add to the CLI options.
+        help: the help message when the CLI is run with --help / -h.
+        required: whether the argument must be provided or not.
+    """
+
+    def get_parser(overrides: bool = False) -> argparse.ArgumentParser:
+        parser = argparse.ArgumentParser(add_help=False)
+        parser_grp = parser.add_argument_group(title=COMMON_TITLE)
+        parser_grp.add_argument(
+            f"--{name.replace('_', '-')}",
+            required=required and not overrides,
+            type=str,
+            default=f"_{name}",
+            help=help,
+        )
+        return parser
+
+    return get_parser
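For illustration, a sketch of the factory pattern above; the argument name and help text here are invented, and the import path is assumed.

import argparse

from nhssynth.cli.common_arguments import suffix_parser_generator  # assumed import path

# Build a parser factory for a hypothetical `--metadata` suffix argument.
get_metadata_parser = suffix_parser_generator(
    name="metadata",
    help="filename suffix for the metadata file",
)
parser = argparse.ArgumentParser(parents=[get_metadata_parser()])
print(parser.parse_args([]).metadata)  # "_metadata" (the generated default)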
 
diff --git a/reference/cli/config/index.html b/reference/cli/config/index.html index 7866705c..cf8a993c 100644 --- a/reference/cli/config/index.html +++ b/reference/cli/config/index.html @@ -2615,8 +2615,7 @@

Source code in src/nhssynth/cli/config.py -
151
-152
+            
152
 153
 154
 155
@@ -2676,68 +2675,69 @@ 

209 210 211 -212

def assemble_config(
-    args: argparse.Namespace,
-    all_subparsers: dict[str, argparse.ArgumentParser],
-) -> dict[str, Any]:
-    """
-    Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.
-
-    Args:
-        args: A namespace object containing all parsed command-line arguments.
-        all_subparsers: A dictionary mapping module names to subparser objects.
-
-    Returns:
-        A dictionary containing configuration information extracted from `args` in a module-wise nested format that is YAML-friendly.
-
-    Raises:
-        ValueError: If `modules_to_run` is neither a single module nor the full pipeline.
-    """
-    args_dict = vars(args)
-
-    # Filter out the keys that are not relevant to the config file
-    args_dict = filter_dict(
-        args_dict, {"func", "experiment_name", "save_config", "save_config_path", "module_handover"}
-    )
-    for k in args_dict.copy().keys():
-        # Remove empty metric lists from the config
-        if "_metrics" in k and not args_dict[k]:
-            args_dict.pop(k)
-
-    modules_to_run = args_dict.pop("modules_to_run")
-    if len(modules_to_run) == 1:
-        run_type = modules_to_run[0]
-    elif modules_to_run == PIPELINE:
-        run_type = "pipeline"
-    else:
-        raise ValueError(f"Invalid value for `modules_to_run`: {modules_to_run}")
-
-    # Generate a dictionary containing each module's name from the run, with all of its possible corresponding config args
-    module_args = {
-        module_name: [action.dest for action in all_subparsers[module_name]._actions if action.dest != "help"]
-        for module_name in modules_to_run
-    }
-
-    # Use the flat namespace to populate a nested (by module) dictionary of config args and values
-    out_dict = {}
-    for module_name in modules_to_run:
-        for k in args_dict.copy().keys():
-            # We want to keep dataset, experiment_name, seed and save_config at the top-level as they are core args
-            if k in module_args[module_name] and k not in {
-                "version",
-                "dataset",
-                "experiment_name",
-                "seed",
-                "save_config",
-            }:
-                if module_name not in out_dict:
-                    out_dict[module_name] = {}
-                v = args_dict.pop(k)
-                if v is not None:
-                    out_dict[module_name][k] = v
-
-    # Assemble the final dictionary in YAML-compliant form
-    return {**({"run_type": run_type} if run_type else {}), **args_dict, **out_dict}
+212
+213
def assemble_config(
+    args: argparse.Namespace,
+    all_subparsers: dict[str, argparse.ArgumentParser],
+) -> dict[str, Any]:
+    """
+    Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.
+
+    Args:
+        args: A namespace object containing all parsed command-line arguments.
+        all_subparsers: A dictionary mapping module names to subparser objects.
+
+    Returns:
+        A dictionary containing configuration information extracted from `args` in a module-wise nested format that is YAML-friendly.
+
+    Raises:
+        ValueError: If `modules_to_run` is neither a single module nor the full pipeline.
+    """
+    args_dict = vars(args)
+
+    # Filter out the keys that are not relevant to the config file
+    args_dict = filter_dict(
+        args_dict, {"func", "experiment_name", "save_config", "save_config_path", "module_handover"}
+    )
+    for k in args_dict.copy().keys():
+        # Remove empty metric lists from the config
+        if "_metrics" in k and not args_dict[k]:
+            args_dict.pop(k)
+
+    modules_to_run = args_dict.pop("modules_to_run")
+    if len(modules_to_run) == 1:
+        run_type = modules_to_run[0]
+    elif modules_to_run == PIPELINE:
+        run_type = "pipeline"
+    else:
+        raise ValueError(f"Invalid value for `modules_to_run`: {modules_to_run}")
+
+    # Generate a dictionary containing each module's name from the run, with all of its possible corresponding config args
+    module_args = {
+        module_name: [action.dest for action in all_subparsers[module_name]._actions if action.dest != "help"]
+        for module_name in modules_to_run
+    }
+
+    # Use the flat namespace to populate a nested (by module) dictionary of config args and values
+    out_dict = {}
+    for module_name in modules_to_run:
+        for k in args_dict.copy().keys():
+            # We want to keep dataset, experiment_name, seed and save_config at the top-level as they are core args
+            if k in module_args[module_name] and k not in {
+                "version",
+                "dataset",
+                "experiment_name",
+                "seed",
+                "save_config",
+            }:
+                if module_name not in out_dict:
+                    out_dict[module_name] = {}
+                v = args_dict.pop(k)
+                if v is not None:
+                    out_dict[module_name][k] = v
+
+    # Assemble the final dictionary in YAML-compliant form
+    return {**({"run_type": run_type} if run_type else {}), **args_dict, **out_dict}
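To illustrate the module-wise nested, YAML-friendly shape described above, the sketch below dumps a hand-written example of such a dictionary; the module names and values are invented, not the output of a real run.

import yaml

# Hypothetical nested configuration in the shape assemble_config produces.
example_config = {
    "run_type": "pipeline",
    "dataset": "support",
    "seed": 123,
    "dataloader": {"missingness": "augment", "collapse_yaml": True},
    "model": {"architecture": ["VAE"], "num_epochs": 50},
}
print(yaml.dump(example_config, default_flow_style=False, sort_keys=False))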
 
@@ -2831,8 +2831,7 @@

Source code in src/nhssynth/cli/config.py -
13
-14
+            
14
 15
 16
 17
@@ -2856,32 +2855,33 @@ 

35 36 37 -38

def get_default_and_required_args(
-    top_parser: argparse.ArgumentParser,
-    module_parsers: dict[str, argparse.ArgumentParser],
-) -> tuple[dict[str, Any], list[dict[str, str]]]:
-    """
-    Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.
-
-    Args:
-        top_parser: The top-level parser (contains common arguments).
-        module_parsers: The dict of module-level parsers mapped to their names.
-
-    Returns:
-        A tuple containing two elements:
-            - A dictionary containing all arguments and their default values.
-            - A list of key-value-pairs of the required arguments and their associated module.
-    """
-    all_actions = {"top-level": top_parser._actions} | {m: p._actions for m, p in module_parsers.items()}
-    defaults = {}
-    required_args = []
-    for module, actions in all_actions.items():
-        for action in actions:
-            if action.dest not in ["help", "==SUPPRESS=="]:
-                defaults[action.dest] = action.default
-                if action.required:
-                    required_args.append({"arg": action.dest, "module": module})
-    return defaults, required_args
+38
+39
def get_default_and_required_args(
+    top_parser: argparse.ArgumentParser,
+    module_parsers: dict[str, argparse.ArgumentParser],
+) -> tuple[dict[str, Any], list[dict[str, str]]]:
+    """
+    Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.
+
+    Args:
+        top_parser: The top-level parser (contains common arguments).
+        module_parsers: The dict of module-level parsers mapped to their names.
+
+    Returns:
+        A tuple containing two elements:
+            - A dictionary containing all arguments and their default values.
+            - A list of key-value-pairs of the required arguments and their associated module.
+    """
+    all_actions = {"top-level": top_parser._actions} | {m: p._actions for m, p in module_parsers.items()}
+    defaults = {}
+    required_args = []
+    for module, actions in all_actions.items():
+        for action in actions:
+            if action.dest not in ["help", "==SUPPRESS=="]:
+                defaults[action.dest] = action.default
+                if action.required:
+                    required_args.append({"arg": action.dest, "module": module})
+    return defaults, required_args
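A self-contained sketch of the scraping logic above, using stand-in parsers rather than the real top-level and module parsers.

import argparse

top = argparse.ArgumentParser(add_help=False)
top.add_argument("--dataset", required=True)
model = argparse.ArgumentParser(add_help=False)
model.add_argument("--batch-size", type=int, default=32)

# Mirror the loop above: collect defaults and flag required arguments per module.
defaults, required_args = {}, []
for module, parser in {"top-level": top, "model": model}.items():
    for action in parser._actions:
        if action.dest not in ["help", "==SUPPRESS=="]:
            defaults[action.dest] = action.default
            if action.required:
                required_args.append({"arg": action.dest, "module": module})

print(defaults)       # {'dataset': None, 'batch_size': 32}
print(required_args)  # [{'arg': 'dataset', 'module': 'top-level'}]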
 
@@ -2959,8 +2959,7 @@

Source code in src/nhssynth/cli/config.py -
135
-136
+            
136
 137
 138
 139
@@ -2972,20 +2971,21 @@ 

145 146 147 -148

def get_modules_to_run(executor: Callable) -> list[str]:
-    """
-    Get the list of modules to run from the passed executor function.
-
-    Args:
-        executor: The executor function to run.
-
-    Returns:
-        A list of module names to run.
-    """
-    if executor == run_pipeline:
-        return PIPELINE
-    else:
-        return [get_key_by_value({mn: mc.func for mn, mc in MODULE_MAP.items()}, executor)]
+148
+149
def get_modules_to_run(executor: Callable) -> list[str]:
+    """
+    Get the list of modules to run from the passed executor function.
+
+    Args:
+        executor: The executor function to run.
+
+    Returns:
+        A list of module names to run.
+    """
+    if executor == run_pipeline:
+        return PIPELINE
+    else:
+        return [get_key_by_value({mn: mc.func for mn, mc in MODULE_MAP.items()}, executor)]
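A minimal illustration of the reverse lookup above, with stand-in executor functions and a stand-in `get_key_by_value` (the real helper lives in `nhssynth.common.dicts`, per the reference pages listed in this patch).

def run_dataloader(args): ...
def run_model(args): ...

def get_key_by_value(d, value):
    # assumed behaviour: return the first key whose value matches
    return next(k for k, v in d.items() if v == value)

MODULE_MAP = {"dataloader": run_dataloader, "model": run_model}  # stand-in mapping
print(get_key_by_value(MODULE_MAP, run_model))  # "model"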
 
@@ -3123,8 +3123,7 @@

Source code in src/nhssynth/cli/config.py -
 41
- 42
+            
 42
  43
  44
  45
@@ -3214,98 +3213,99 @@ 

129 130 131 -132

def read_config(
-    args: argparse.Namespace,
-    parser: argparse.ArgumentParser,
-    all_subparsers: dict[str, argparse.ArgumentParser],
-) -> argparse.Namespace:
-    """
-    Hierarchically assembles a config `argparse.Namespace` object for the inferred modules to run and execute, given a file.
-
-    1. Load the YAML file containing the config to read from
-    2. Check a valid `run_type` is specified or infer it and determine the list of `modules_to_run`
-    3. Establish the appropriate default configuration set of arguments from the `parser` and `all_subparsers` for the determined `modules_to_run`
-    4. Overwrite these with the specified (sub)set of config in the YAML file
-    5. Overwrite again with passed command-line `args` (these are considered 'overrides')
-    6. Run the appropriate module(s) or pipeline with the resulting configuration `Namespace` object
-
-    Args:
-        args: Namespace object containing arguments from the command line
-        parser: top-level `ArgumentParser` object containing common arguments
-        all_subparsers: dictionary of `ArgumentParser` objects, one for each module
-
-    Returns:
-        A Namespace object containing the assembled configuration settings
-
-    Raises:
-        AssertionError: if any required arguments are missing from the configuration file / overrides
-    """
-    # Open the passed yaml file and load into a dictionary
-    with open(f"config/{args.input_config}.yaml") as stream:
-        config_dict = yaml.safe_load(stream)
-
-    valid_run_types = [x for x in all_subparsers.keys() if x != "config"]
-
-    version = config_dict.pop("version", None)
-    if version and version != ver("nhssynth"):
-        warnings.warn(
-            f"This config file's specified version ({version}) does not match the currently installed version of nhssynth ({ver('nhssynth')}), results may differ."
-        )
-    elif not version:
-        version = ver("nhssynth")
-
-    run_type = config_dict.pop("run_type", None)
-
-    if run_type == "pipeline":
-        modules_to_run = PIPELINE
-    else:
-        modules_to_run = [x for x in config_dict.keys() | {run_type} if x in valid_run_types]
-        if not args.custom_pipeline:
-            modules_to_run = sorted(modules_to_run, key=lambda x: PIPELINE.index(x))
-
-    if not modules_to_run:
-        warnings.warn(
-            f"Missing or invalid `run_type` and/or module specification hierarchy in `config/{args.input_config}.yaml`, defaulting to a full run of the pipeline"
-        )
-        modules_to_run = PIPELINE
-
-    # Get all possible default arguments by scraping the top level `parser` and the appropriate sub-parser for the `run_type`
-    args_dict, required_args = get_default_and_required_args(
-        parser, filter_dict(all_subparsers, modules_to_run, include=True)
-    )
-
-    # Find the non-default arguments amongst passed `args` by seeing which of them are different to the entries of `args_dict`
-    non_default_passed_args_dict = {
-        k: v
-        for k, v in vars(args).items()
-        if k in ["input_config", "custom_pipeline"] or (k in args_dict and k != "func" and v != args_dict[k])
-    }
-
-    # Overwrite the default arguments with the ones from the yaml file
-    args_dict.update(flatten_dict(config_dict))
-
-    # Overwrite the result of the above with any non-default CLI args
-    args_dict.update(non_default_passed_args_dict)
-
-    # Create a new Namespace using the assembled dictionary
-    new_args = argparse.Namespace(**args_dict)
-    assert getattr(
-        new_args, "dataset"
-    ), "No dataset specified in the passed config file, provide one with the `--dataset` argument or add it to the config file"
-    assert all(
-        getattr(new_args, req_arg["arg"]) for req_arg in required_args
-    ), f"Required arguments are missing from the passed config file: {[ra['module'] + ':' + ra['arg'] for ra in required_args if not getattr(new_args, ra['arg'])]}"
-
-    # Run the appropriate execution function(s)
-    if not new_args.seed:
-        warnings.warn("No seed has been specified, meaning the results of this run may not be reproducible.")
-    new_args.version = version
-    new_args.modules_to_run = modules_to_run
-    new_args.module_handover = {}
-    for module in new_args.modules_to_run:
-        MODULE_MAP[module](new_args)
-
-    return new_args
+132
+133
def read_config(
+    args: argparse.Namespace,
+    parser: argparse.ArgumentParser,
+    all_subparsers: dict[str, argparse.ArgumentParser],
+) -> argparse.Namespace:
+    """
+    Hierarchically assembles a config `argparse.Namespace` object for the inferred modules to run and execute, given a file.
+
+    1. Load the YAML file containing the config to read from
+    2. Check a valid `run_type` is specified or infer it and determine the list of `modules_to_run`
+    3. Establish the appropriate default configuration set of arguments from the `parser` and `all_subparsers` for the determined `modules_to_run`
+    4. Overwrite these with the specified (sub)set of config in the YAML file
+    5. Overwrite again with passed command-line `args` (these are considered 'overrides')
+    6. Run the appropriate module(s) or pipeline with the resulting configuration `Namespace` object
+
+    Args:
+        args: Namespace object containing arguments from the command line
+        parser: top-level `ArgumentParser` object containing common arguments
+        all_subparsers: dictionary of `ArgumentParser` objects, one for each module
+
+    Returns:
+        A Namespace object containing the assembled configuration settings
+
+    Raises:
+        AssertionError: if any required arguments are missing from the configuration file / overrides
+    """
+    # Open the passed yaml file and load into a dictionary
+    with open(f"config/{args.input_config}.yaml") as stream:
+        config_dict = yaml.safe_load(stream)
+
+    valid_run_types = [x for x in all_subparsers.keys() if x != "config"]
+
+    version = config_dict.pop("version", None)
+    if version and version != ver("nhssynth"):
+        warnings.warn(
+            f"This config file's specified version ({version}) does not match the currently installed version of nhssynth ({ver('nhssynth')}), results may differ."
+        )
+    elif not version:
+        version = ver("nhssynth")
+
+    run_type = config_dict.pop("run_type", None)
+
+    if run_type == "pipeline":
+        modules_to_run = PIPELINE
+    else:
+        modules_to_run = [x for x in config_dict.keys() | {run_type} if x in valid_run_types]
+        if not args.custom_pipeline:
+            modules_to_run = sorted(modules_to_run, key=lambda x: PIPELINE.index(x))
+
+    if not modules_to_run:
+        warnings.warn(
+            f"Missing or invalid `run_type` and/or module specification hierarchy in `config/{args.input_config}.yaml`, defaulting to a full run of the pipeline"
+        )
+        modules_to_run = PIPELINE
+
+    # Get all possible default arguments by scraping the top level `parser` and the appropriate sub-parser for the `run_type`
+    args_dict, required_args = get_default_and_required_args(
+        parser, filter_dict(all_subparsers, modules_to_run, include=True)
+    )
+
+    # Find the non-default arguments amongst passed `args` by seeing which of them are different to the entries of `args_dict`
+    non_default_passed_args_dict = {
+        k: v
+        for k, v in vars(args).items()
+        if k in ["input_config", "custom_pipeline"] or (k in args_dict and k != "func" and v != args_dict[k])
+    }
+
+    # Overwrite the default arguments with the ones from the yaml file
+    args_dict.update(flatten_dict(config_dict))
+
+    # Overwrite the result of the above with any non-default CLI args
+    args_dict.update(non_default_passed_args_dict)
+
+    # Create a new Namespace using the assembled dictionary
+    new_args = argparse.Namespace(**args_dict)
+    assert getattr(
+        new_args, "dataset"
+    ), "No dataset specified in the passed config file, provide one with the `--dataset` argument or add it to the config file"
+    assert all(
+        getattr(new_args, req_arg["arg"]) for req_arg in required_args
+    ), f"Required arguments are missing from the passed config file: {[ra['module'] + ':' + ra['arg'] for ra in required_args if not getattr(new_args, ra['arg'])]}"
+
+    # Run the appropriate execution function(s)
+    if not new_args.seed:
+        warnings.warn("No seed has been specified, meaning the results of this run may not be reproducible.")
+    new_args.version = version
+    new_args.modules_to_run = modules_to_run
+    new_args.module_handover = {}
+    for module in new_args.modules_to_run:
+        MODULE_MAP[module](new_args)
+
+    return new_args
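For reference, a config file consistent with the hierarchy read above might look like the YAML below, loaded here from a string; all field values are illustrative.

import yaml

example_yaml = """
run_type: pipeline
dataset: support
seed: 123
dataloader:
  missingness: augment
model:
  architecture: [VAE]
  num_epochs: 50
"""
config_dict = yaml.safe_load(example_yaml)
print(config_dict.pop("run_type"), sorted(config_dict))  # pipeline ['dataloader', 'dataset', 'model', 'seed']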
 
@@ -3373,8 +3373,7 @@

Source code in src/nhssynth/cli/config.py -
215
-216
+            
216
 217
 218
 219
@@ -3387,21 +3386,22 @@ 

226 227 228 -229

def write_config(
-    args: argparse.Namespace,
-    all_subparsers: dict[str, argparse.ArgumentParser],
-) -> None:
-    """
-    Assembles a configuration dictionary from the run config and writes it to a YAML file at `experiments/<EXPERIMENT_NAME>/config_<EXPERIMENT_NAME>.yaml`.
-
-    Args:
-        args: A namespace containing the run's configuration.
-        all_subparsers: A dictionary containing all subparsers for the config args.
-    """
-    experiment_name = args.experiment_name
-    args_dict = assemble_config(args, all_subparsers)
-    with open(f"experiments/{experiment_name}/config_{experiment_name}.yaml", "w") as yaml_file:
-        yaml.dump(args_dict, yaml_file, default_flow_style=False, sort_keys=False)
+229
+230
def write_config(
+    args: argparse.Namespace,
+    all_subparsers: dict[str, argparse.ArgumentParser],
+) -> None:
+    """
+    Assembles a configuration dictionary from the run config and writes it to a YAML file at `experiments/<EXPERIMENT_NAME>/config_<EXPERIMENT_NAME>.yaml`.
+
+    Args:
+        args: A namespace containing the run's configuration.
+        all_subparsers: A dictionary containing all subparsers for the config args.
+    """
+    experiment_name = args.experiment_name
+    args_dict = assemble_config(args, all_subparsers)
+    with open(f"experiments/{experiment_name}/config_{experiment_name}.yaml", "w") as yaml_file:
+        yaml.dump(args_dict, yaml_file, default_flow_style=False, sort_keys=False)
 
diff --git a/reference/cli/model_arguments/index.html b/reference/cli/model_arguments/index.html index 353b9c85..5a36a58e 100644 --- a/reference/cli/model_arguments/index.html +++ b/reference/cli/model_arguments/index.html @@ -2505,8 +2505,7 @@

Source code in src/nhssynth/cli/model_arguments.py -
 76
- 77
+            
 77
  78
  79
  80
@@ -2600,102 +2599,103 @@ 

168 169 170 -171

def add_gan_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:
-    """Adds arguments to an existing group for the GAN model."""
-    group.add_argument(
-        "--n-units-conditional",
-        type=int,
-        help="the number of units in the conditional layer",
-    )
-    group.add_argument(
-        "--generator-n-layers-hidden",
-        type=int,
-        help="the number of hidden layers in the generator",
-    )
-    group.add_argument(
-        "--generator-n-units-hidden",
-        type=int,
-        help="the number of units in each hidden layer of the generator",
-    )
-    group.add_argument(
-        "--generator-activation",
-        type=str,
-        choices=list(ACTIVATION_FUNCTIONS.keys()),
-        help="the activation function of the generator",
-    )
-    group.add_argument(
-        "--generator-batch-norm",
-        action="store_true",
-        help="whether to use batch norm in the generator",
-    )
-    group.add_argument(
-        "--generator-dropout",
-        type=float,
-        help="the dropout rate in the generator",
-    )
-    group.add_argument(
-        "--generator-lr",
-        type=float,
-        help="the learning rate for the generator",
-    )
-    group.add_argument(
-        "--generator-residual",
-        action="store_true",
-        help="whether to use residual connections in the generator",
-    )
-    group.add_argument(
-        "--generator-opt-betas",
-        type=float,
-        nargs=2,
-        help="the beta values for the generator optimizer",
-    )
-    group.add_argument(
-        "--discriminator-n-layers-hidden",
-        type=int,
-        help="the number of hidden layers in the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-n-units-hidden",
-        type=int,
-        help="the number of units in each hidden layer of the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-activation",
-        type=str,
-        choices=list(ACTIVATION_FUNCTIONS.keys()),
-        help="the activation function of the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-batch-norm",
-        action="store_true",
-        help="whether to use batch norm in the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-dropout",
-        type=float,
-        help="the dropout rate in the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-lr",
-        type=float,
-        help="the learning rate for the discriminator",
-    )
-    group.add_argument(
-        "--discriminator-opt-betas",
-        type=float,
-        nargs=2,
-        help="the beta values for the discriminator optimizer",
-    )
-    group.add_argument(
-        "--clipping-value",
-        type=float,
-        help="the clipping value for the discriminator",
-    )
-    group.add_argument(
-        "--lambda-gradient-penalty",
-        type=float,
-        help="the gradient penalty coefficient",
-    )
+171
+172
def add_gan_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:
+    """Adds arguments to an existing group for the GAN model."""
+    group.add_argument(
+        "--n-units-conditional",
+        type=int,
+        help="the number of units in the conditional layer",
+    )
+    group.add_argument(
+        "--generator-n-layers-hidden",
+        type=int,
+        help="the number of hidden layers in the generator",
+    )
+    group.add_argument(
+        "--generator-n-units-hidden",
+        type=int,
+        help="the number of units in each hidden layer of the generator",
+    )
+    group.add_argument(
+        "--generator-activation",
+        type=str,
+        choices=list(ACTIVATION_FUNCTIONS.keys()),
+        help="the activation function of the generator",
+    )
+    group.add_argument(
+        "--generator-batch-norm",
+        action="store_true",
+        help="whether to use batch norm in the generator",
+    )
+    group.add_argument(
+        "--generator-dropout",
+        type=float,
+        help="the dropout rate in the generator",
+    )
+    group.add_argument(
+        "--generator-lr",
+        type=float,
+        help="the learning rate for the generator",
+    )
+    group.add_argument(
+        "--generator-residual",
+        action="store_true",
+        help="whether to use residual connections in the generator",
+    )
+    group.add_argument(
+        "--generator-opt-betas",
+        type=float,
+        nargs=2,
+        help="the beta values for the generator optimizer",
+    )
+    group.add_argument(
+        "--discriminator-n-layers-hidden",
+        type=int,
+        help="the number of hidden layers in the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-n-units-hidden",
+        type=int,
+        help="the number of units in each hidden layer of the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-activation",
+        type=str,
+        choices=list(ACTIVATION_FUNCTIONS.keys()),
+        help="the activation function of the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-batch-norm",
+        action="store_true",
+        help="whether to use batch norm in the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-dropout",
+        type=float,
+        help="the dropout rate in the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-lr",
+        type=float,
+        help="the learning rate for the discriminator",
+    )
+    group.add_argument(
+        "--discriminator-opt-betas",
+        type=float,
+        nargs=2,
+        help="the beta values for the discriminator optimizer",
+    )
+    group.add_argument(
+        "--clipping-value",
+        type=float,
+        help="the clipping value for the discriminator",
+    )
+    group.add_argument(
+        "--lambda-gradient-penalty",
+        type=float,
+        help="the gradient penalty coefficient",
+    )
 
@@ -2719,21 +2719,21 @@

Source code in src/nhssynth/cli/model_arguments.py -
 7
- 8
+            
def add_model_specific_args(group: argparse._ArgumentGroup, name: str, overrides: bool = False) -> None:
-    """Adds arguments to an existing group according to `name`."""
-    if name == "VAE":
-        add_vae_args(group, overrides)
-    elif name == "GAN":
-        add_gan_args(group, overrides)
-    elif name == "TabularGAN":
-        add_tabular_gan_args(group, overrides)
+14
+15
def add_model_specific_args(group: argparse._ArgumentGroup, name: str, overrides: bool = False) -> None:
+    """Adds arguments to an existing group according to `name`."""
+    if name == "VAE":
+        add_vae_args(group, overrides)
+    elif name == "GAN":
+        add_gan_args(group, overrides)
+    elif name == "TabularGAN":
+        add_tabular_gan_args(group, overrides)
 
@@ -2757,8 +2757,7 @@

Source code in src/nhssynth/cli/model_arguments.py -
17
-18
+            
18
 19
 20
 21
@@ -2813,63 +2812,64 @@ 

70 71 72 -73

def add_vae_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:
-    """Adds arguments to an existing group for the VAE model."""
-    group.add_argument(
-        "--encoder-latent-dim",
-        type=int,
-        nargs="+",
-        help="the latent dimension of the encoder",
-    )
-    group.add_argument(
-        "--encoder-hidden-dim",
-        type=int,
-        nargs="+",
-        help="the hidden dimension of the encoder",
-    )
-    group.add_argument(
-        "--encoder-activation",
-        type=str,
-        nargs="+",
-        choices=list(ACTIVATION_FUNCTIONS.keys()),
-        help="the activation function of the encoder",
-    )
-    group.add_argument(
-        "--encoder-learning-rate",
-        type=float,
-        nargs="+",
-        help="the learning rate for the encoder",
-    )
-    group.add_argument(
-        "--decoder-latent-dim",
-        type=int,
-        nargs="+",
-        help="the latent dimension of the decoder",
-    )
-    group.add_argument(
-        "--decoder-hidden-dim",
-        type=int,
-        nargs="+",
-        help="the hidden dimension of the decoder",
-    )
-    group.add_argument(
-        "--decoder-activation",
-        type=str,
-        nargs="+",
-        choices=list(ACTIVATION_FUNCTIONS.keys()),
-        help="the activation function of the decoder",
-    )
-    group.add_argument(
-        "--decoder-learning-rate",
-        type=float,
-        nargs="+",
-        help="the learning rate for the decoder",
-    )
-    group.add_argument(
-        "--shared-optimizer",
-        action="store_true",
-        help="whether to use a shared optimizer for the encoder and decoder",
-    )
+73
+74
def add_vae_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:
+    """Adds arguments to an existing group for the VAE model."""
+    group.add_argument(
+        "--encoder-latent-dim",
+        type=int,
+        nargs="+",
+        help="the latent dimension of the encoder",
+    )
+    group.add_argument(
+        "--encoder-hidden-dim",
+        type=int,
+        nargs="+",
+        help="the hidden dimension of the encoder",
+    )
+    group.add_argument(
+        "--encoder-activation",
+        type=str,
+        nargs="+",
+        choices=list(ACTIVATION_FUNCTIONS.keys()),
+        help="the activation function of the encoder",
+    )
+    group.add_argument(
+        "--encoder-learning-rate",
+        type=float,
+        nargs="+",
+        help="the learning rate for the encoder",
+    )
+    group.add_argument(
+        "--decoder-latent-dim",
+        type=int,
+        nargs="+",
+        help="the latent dimension of the decoder",
+    )
+    group.add_argument(
+        "--decoder-hidden-dim",
+        type=int,
+        nargs="+",
+        help="the hidden dimension of the decoder",
+    )
+    group.add_argument(
+        "--decoder-activation",
+        type=str,
+        nargs="+",
+        choices=list(ACTIVATION_FUNCTIONS.keys()),
+        help="the activation function of the decoder",
+    )
+    group.add_argument(
+        "--decoder-learning-rate",
+        type=float,
+        nargs="+",
+        help="the learning rate for the decoder",
+    )
+    group.add_argument(
+        "--shared-optimizer",
+        action="store_true",
+        help="whether to use a shared optimizer for the encoder and decoder",
+    )
 
diff --git a/reference/cli/module_arguments/index.html b/reference/cli/module_arguments/index.html index 43779f0b..b5814a89 100644 --- a/reference/cli/module_arguments/index.html +++ b/reference/cli/module_arguments/index.html @@ -2533,8 +2533,7 @@

Source code in src/nhssynth/cli/module_arguments.py -
10
-11
+              
11
 12
 13
 14
@@ -2551,25 +2550,26 @@ 

25 26 27 -28

class AllChoicesDefault(argparse.Action):
-    """
-    Customised argparse action for defaulting to the full list of choices if only the argument's flag is supplied:
-    (i.e. user passes `--metrics` with no follow up list of metric groups => all metric groups will be executed).
-
-    Notes:
-        1) If no `option_string` is supplied: set to default value (`self.default`)
-        2) If `option_string` is supplied:
-            a) If `values` are supplied, set to list of values
-            b) If no `values` are supplied, set to `self.const`, if `self.const` is not set, set to `self.default`
-    """
-
-    def __call__(self, parser, namespace, values=None, option_string=None):
-        if values:
-            setattr(namespace, self.dest, values)
-        elif option_string:
-            setattr(namespace, self.dest, self.const if self.const else self.default)
-        else:
-            setattr(namespace, self.dest, self.default)
+28
+29
class AllChoicesDefault(argparse.Action):
+    """
+    Customised argparse action for defaulting to the full list of choices if only the argument's flag is supplied:
+    (i.e. user passes `--metrics` with no follow up list of metric groups => all metric groups will be executed).
+
+    Notes:
+        1) If no `option_string` is supplied: set to default value (`self.default`)
+        2) If `option_string` is supplied:
+            a) If `values` are supplied, set to list of values
+            b) If no `values` are supplied, set to `self.const`, if `self.const` is not set, set to `self.default`
+    """
+
+    def __call__(self, parser, namespace, values=None, option_string=None):
+        if values:
+            setattr(namespace, self.dest, values)
+        elif option_string:
+            setattr(namespace, self.dest, self.const if self.const else self.default)
+        else:
+            setattr(namespace, self.dest, self.default)
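A usage sketch of the action above; the import path and the metric-group names are assumptions, not taken from the real CLI.

import argparse

from nhssynth.cli.module_arguments import AllChoicesDefault  # assumed import path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--metrics",
    action=AllChoicesDefault,
    nargs="*",
    default=None,
    const=["univariate", "bivariate", "privacy"],  # illustrative metric groups
)
print(parser.parse_args([]).metrics)                        # None (flag absent)
print(parser.parse_args(["--metrics"]).metrics)             # the full `const` list
print(parser.parse_args(["--metrics", "privacy"]).metrics)  # ['privacy']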
 
@@ -2612,8 +2612,7 @@

Source code in src/nhssynth/cli/module_arguments.py -
31
-32
+            
32
 33
 34
 35
@@ -2655,50 +2654,51 @@ 

71 72 73 -74

def add_dataloader_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
-    """Adds arguments to an existing dataloader module sub-parser instance."""
-    group = parser.add_argument_group(title=group_title)
-    group.add_argument(
-        "--data-dir",
-        type=str,
-        default="./data",
-        help="the directory containing the chosen dataset",
-    )
-    group.add_argument(
-        "--index-col",
-        default=None,
-        nargs="*",
-        help="indicate the name of the index column(s) in the csv file, such that pandas can index by it",
-    )
-    group.add_argument(
-        "--constraint-graph",
-        type=str,
-        default="_constraint_graph",
-        help="the name of the html file to write the constraint graph to, defaults to `<DATASET>_constraint_graph`",
-    )
-    group.add_argument(
-        "--collapse-yaml",
-        action="store_true",
-        help="use aliases and anchors in the output metadata yaml, this will make it much more compact",
-    )
-    group.add_argument(
-        "--missingness",
-        type=str,
-        default="augment",
-        choices=MISSINGNESS_STRATEGIES,
-        help="how to handle missing values in the dataset",
-    )
-    group.add_argument(
-        "--impute",
-        type=str,
-        default=None,
-        help="the imputation strategy to use, ONLY USED if <MISSINGNESS> is set to 'impute', choose from: 'mean', 'median', 'mode', or any specific value (e.g. '0')",
-    )
-    group.add_argument(
-        "--write-csv",
-        action="store_true",
-        help="write the transformed real data to a csv file",
-    )
+74
+75
def add_dataloader_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
+    """Adds arguments to an existing dataloader module sub-parser instance."""
+    group = parser.add_argument_group(title=group_title)
+    group.add_argument(
+        "--data-dir",
+        type=str,
+        default="./data",
+        help="the directory containing the chosen dataset",
+    )
+    group.add_argument(
+        "--index-col",
+        default=None,
+        nargs="*",
+        help="indicate the name of the index column(s) in the csv file, such that pandas can index by it",
+    )
+    group.add_argument(
+        "--constraint-graph",
+        type=str,
+        default="_constraint_graph",
+        help="the name of the html file to write the constraint graph to, defaults to `<DATASET>_constraint_graph`",
+    )
+    group.add_argument(
+        "--collapse-yaml",
+        action="store_true",
+        help="use aliases and anchors in the output metadata yaml, this will make it much more compact",
+    )
+    group.add_argument(
+        "--missingness",
+        type=str,
+        default="augment",
+        choices=MISSINGNESS_STRATEGIES,
+        help="how to handle missing values in the dataset",
+    )
+    group.add_argument(
+        "--impute",
+        type=str,
+        default=None,
+        help="the imputation strategy to use, ONLY USED if <MISSINGNESS> is set to 'impute', choose from: 'mean', 'median', 'mode', or any specific value (e.g. '0')",
+    )
+    group.add_argument(
+        "--write-csv",
+        action="store_true",
+        help="write the transformed real data to a csv file",
+    )
 
@@ -2722,8 +2722,7 @@

Source code in src/nhssynth/cli/module_arguments.py -
181
-182
+            
182
 183
 184
 185
@@ -2778,63 +2777,64 @@ 

234 235 236 -237

def add_evaluation_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
-    """Adds arguments to an existing evaluation module sub-parser instance."""
-    group = parser.add_argument_group(title=group_title)
-    group.add_argument(
-        "--downstream-tasks",
-        "--tasks",
-        action="store_true",
-        help="run the downstream tasks evaluation",
-    )
-    group.add_argument(
-        "--tasks-dir",
-        type=str,
-        default="./tasks",
-        help="the directory containing the downstream tasks to run, this directory must contain a folder called <DATASET> containing the tasks to run",
-    )
-    group.add_argument(
-        "--aequitas",
-        action="store_true",
-        help="run the aequitas fairness evaluation (note this runs for each of the downstream tasks)",
-    )
-    group.add_argument(
-        "--aequitas-attributes",
-        type=str,
-        nargs="+",
-        default=None,
-        help="the attributes to use for the aequitas fairness evaluation, defaults to all attributes",
-    )
-    group.add_argument(
-        "--key-numerical-fields",
-        type=str,
-        nargs="+",
-        default=None,
-        help="the numerical key field attributes to use for SDV privacy evaluations",
-    )
-    group.add_argument(
-        "--sensitive-numerical-fields",
-        type=str,
-        nargs="+",
-        default=None,
-        help="the numerical sensitive field attributes to use for SDV privacy evaluations",
-    )
-    group.add_argument(
-        "--key-categorical-fields",
-        type=str,
-        nargs="+",
-        default=None,
-        help="the categorical key field attributes to use for SDV privacy evaluations",
-    )
-    group.add_argument(
-        "--sensitive-categorical-fields",
-        type=str,
-        nargs="+",
-        default=None,
-        help="the categorical sensitive field attributes to use for SDV privacy evaluations",
-    )
-    for name in METRIC_CHOICES:
-        generate_evaluation_arg(group, name)
+237
+238
def add_evaluation_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
+    """Adds arguments to an existing evaluation module sub-parser instance."""
+    group = parser.add_argument_group(title=group_title)
+    group.add_argument(
+        "--downstream-tasks",
+        "--tasks",
+        action="store_true",
+        help="run the downstream tasks evaluation",
+    )
+    group.add_argument(
+        "--tasks-dir",
+        type=str,
+        default="./tasks",
+        help="the directory containing the downstream tasks to run, this directory must contain a folder called <DATASET> containing the tasks to run",
+    )
+    group.add_argument(
+        "--aequitas",
+        action="store_true",
+        help="run the aequitas fairness evaluation (note this runs for each of the downstream tasks)",
+    )
+    group.add_argument(
+        "--aequitas-attributes",
+        type=str,
+        nargs="+",
+        default=None,
+        help="the attributes to use for the aequitas fairness evaluation, defaults to all attributes",
+    )
+    group.add_argument(
+        "--key-numerical-fields",
+        type=str,
+        nargs="+",
+        default=None,
+        help="the numerical key field attributes to use for SDV privacy evaluations",
+    )
+    group.add_argument(
+        "--sensitive-numerical-fields",
+        type=str,
+        nargs="+",
+        default=None,
+        help="the numerical sensitive field attributes to use for SDV privacy evaluations",
+    )
+    group.add_argument(
+        "--key-categorical-fields",
+        type=str,
+        nargs="+",
+        default=None,
+        help="the categorical key field attributes to use for SDV privacy evaluations",
+    )
+    group.add_argument(
+        "--sensitive-categorical-fields",
+        type=str,
+        nargs="+",
+        default=None,
+        help="the categorical sensitive field attributes to use for SDV privacy evaluations",
+    )
+    for name in METRIC_CHOICES:
+        generate_evaluation_arg(group, name)
 
@@ -2858,8 +2858,7 @@

Source code in src/nhssynth/cli/module_arguments.py -
 81
- 82
+            
 82
  83
  84
  85
@@ -2942,91 +2941,92 @@ 

162 163 164 -165

def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
-    """Adds arguments to an existing model module sub-parser instance."""
-    group = parser.add_argument_group(title=group_title)
-    group.add_argument(
-        "--architecture",
-        type=str,
-        nargs="+",
-        default=["VAE"],
-        choices=MODELS,
-        help="the model architecture(s) to train",
-    )
-    group.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="how many times to repeat the training process per model architecture (<SEED> is incremented each time)",
-    )
-    group.add_argument(
-        "--batch-size",
-        type=int,
-        nargs="+",
-        default=32,
-        help="the batch size for the model",
-    )
-    group.add_argument(
-        "--num-epochs",
-        type=int,
-        nargs="+",
-        default=100,
-        help="number of epochs to train for",
-    )
-    group.add_argument(
-        "--patience",
-        type=int,
-        nargs="+",
-        default=5,
-        help="how many epochs the model is allowed to train for without improvement",
-    )
-    group.add_argument(
-        "--displayed-metrics",
-        type=str,
-        nargs="+",
-        default=[],
-        help="metrics to display during training of the model, when set to `None`, all metrics are displayed",
-    )
-    group.add_argument(
-        "--use-gpu",
-        action="store_true",
-        help="use the GPU for training",
-    )
-    group.add_argument(
-        "--num-samples",
-        type=int,
-        default=None,
-        help="the number of samples to generate from the model, defaults to the size of the original dataset",
-    )
-    privacy_group = parser.add_argument_group(title="model privacy options")
-    privacy_group.add_argument(
-        "--target-epsilon",
-        type=float,
-        nargs="+",
-        default=1.0,
-        help="the target epsilon for differential privacy",
-    )
-    privacy_group.add_argument(
-        "--target-delta",
-        type=float,
-        nargs="+",
-        help="the target delta for differential privacy, defaults to `1 / len(dataset)` if not specified",
-    )
-    privacy_group.add_argument(
-        "--max-grad-norm",
-        type=float,
-        nargs="+",
-        default=5.0,
-        help="the clipping threshold for gradients (only relevant under differential privacy)",
-    )
-    privacy_group.add_argument(
-        "--secure-mode",
-        action="store_true",
-        help="Enable secure RNG via the `csprng` package to make privacy guarantees more robust, comes at a cost of performance and reproducibility",
-    )
-    for model_name in MODELS.keys():
-        model_group = parser.add_argument_group(title=f"{model_name}-specific options")
-        add_model_specific_args(model_group, model_name, overrides=overrides)
+165
+166
def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
+    """Adds arguments to an existing model module sub-parser instance."""
+    group = parser.add_argument_group(title=group_title)
+    group.add_argument(
+        "--architecture",
+        type=str,
+        nargs="+",
+        default=["VAE"],
+        choices=MODELS,
+        help="the model architecture(s) to train",
+    )
+    group.add_argument(
+        "--repeats",
+        type=int,
+        default=1,
+        help="how many times to repeat the training process per model architecture (<SEED> is incremented each time)",
+    )
+    group.add_argument(
+        "--batch-size",
+        type=int,
+        nargs="+",
+        default=32,
+        help="the batch size for the model",
+    )
+    group.add_argument(
+        "--num-epochs",
+        type=int,
+        nargs="+",
+        default=100,
+        help="number of epochs to train for",
+    )
+    group.add_argument(
+        "--patience",
+        type=int,
+        nargs="+",
+        default=5,
+        help="how many epochs the model is allowed to train for without improvement",
+    )
+    group.add_argument(
+        "--displayed-metrics",
+        type=str,
+        nargs="+",
+        default=[],
+        help="metrics to display during training of the model, when set to `None`, all metrics are displayed",
+    )
+    group.add_argument(
+        "--use-gpu",
+        action="store_true",
+        help="use the GPU for training",
+    )
+    group.add_argument(
+        "--num-samples",
+        type=int,
+        default=None,
+        help="the number of samples to generate from the model, defaults to the size of the original dataset",
+    )
+    privacy_group = parser.add_argument_group(title="model privacy options")
+    privacy_group.add_argument(
+        "--target-epsilon",
+        type=float,
+        nargs="+",
+        default=1.0,
+        help="the target epsilon for differential privacy",
+    )
+    privacy_group.add_argument(
+        "--target-delta",
+        type=float,
+        nargs="+",
+        help="the target delta for differential privacy, defaults to `1 / len(dataset)` if not specified",
+    )
+    privacy_group.add_argument(
+        "--max-grad-norm",
+        type=float,
+        nargs="+",
+        default=5.0,
+        help="the clipping threshold for gradients (only relevant under differential privacy)",
+    )
+    privacy_group.add_argument(
+        "--secure-mode",
+        action="store_true",
+        help="Enable secure RNG via the `csprng` package to make privacy guarantees more robust, comes at a cost of performance and reproducibility",
+    )
+    for model_name in MODELS.keys():
+        model_group = parser.add_argument_group(title=f"{model_name}-specific options")
+        add_model_specific_args(model_group, model_name, overrides=overrides)
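A hedged example of parsing the options registered above; it assumes the package is installed and that "VAE" and "GAN" are keys of `MODELS`, as the model-specific helpers suggest.

import argparse

from nhssynth.cli.module_arguments import add_model_args  # assumed import path

parser = argparse.ArgumentParser(prog="nhssynth model")
add_model_args(parser, "model options")
args = parser.parse_args(["--architecture", "VAE", "GAN", "--num-epochs", "50", "--use-gpu"])
print(args.architecture, args.num_epochs, args.use_gpu)  # ['VAE', 'GAN'] [50] True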
 
@@ -3050,8 +3050,7 @@

Source code in src/nhssynth/cli/module_arguments.py -
240
-241
+            
241
 242
 243
 244
@@ -3072,29 +3071,30 @@ 

259 260 261 -262

def add_plotting_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
-    """Adds arguments to an existing plotting module sub-parser instance."""
-    group = parser.add_argument_group(title=group_title)
-    group.add_argument(
-        "--plot-quality",
-        action="store_true",
-        help="plot the SDV quality report",
-    )
-    group.add_argument(
-        "--plot-diagnostic",
-        action="store_true",
-        help="plot the SDV diagnostic report",
-    )
-    group.add_argument(
-        "--plot-sdv-report",
-        action="store_true",
-        help="plot the SDV report",
-    )
-    group.add_argument(
-        "--plot-tsne",
-        action="store_true",
-        help="plot the t-SNE embeddings of the real and synthetic data",
-    )
+262
+263
def add_plotting_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:
+    """Adds arguments to an existing plotting module sub-parser instance."""
+    group = parser.add_argument_group(title=group_title)
+    group.add_argument(
+        "--plot-quality",
+        action="store_true",
+        help="plot the SDV quality report",
+    )
+    group.add_argument(
+        "--plot-diagnostic",
+        action="store_true",
+        help="plot the SDV diagnostic report",
+    )
+    group.add_argument(
+        "--plot-sdv-report",
+        action="store_true",
+        help="plot the SDV report",
+    )
+    group.add_argument(
+        "--plot-tsne",
+        action="store_true",
+        help="plot the t-SNE embeddings of the real and synthetic data",
+    )
 
diff --git a/reference/cli/module_setup/index.html b/reference/cli/module_setup/index.html index d1efd388..94e8e197 100644 --- a/reference/cli/module_setup/index.html +++ b/reference/cli/module_setup/index.html @@ -2597,8 +2597,7 @@

Source code in src/nhssynth/cli/module_setup.py -
10
-11
+              
11
 12
 13
 14
@@ -2632,42 +2631,43 @@ 

42 43 44 -45

class ModuleConfig:
-    """
-    Represents a module's configuration, containing the following attributes:
-
-    Attributes:
-        func: A callable that executes the module's functionality.
-        add_args: A callable that populates the module's sub-parser arguments.
-        description: A description of the module's functionality.
-        help: A help message for the module's command-line interface.
-        common_parsers: A list of common parsers to add to the module's sub-parser, appending the 'dataset' and 'core' parsers to those passed.
-    """
-
-    def __init__(
-        self,
-        func: Callable[..., argparse.Namespace],
-        add_args: Callable[..., None],
-        description: str,
-        help: str,
-        common_parsers: Optional[list[str]] = None,
-        no_seed: bool = False,
-    ) -> None:
-        self.func = func
-        self.add_args = add_args
-        self.description = description
-        self.help = help
-        self.common_parsers = ["core", "seed"] if not no_seed else ["core"]
-        if common_parsers:
-            assert set(common_parsers) <= COMMON_PARSERS.keys(), "Invalid common parser(s) specified."
-            # merge the below two assert statements
-            assert (
-                "core" not in common_parsers and "seed" not in common_parsers
-            ), "The 'seed' and 'core' parser groups are automatically added to all modules, remove the from `ModuleConfig`s."
-            self.common_parsers += common_parsers
-
-    def __call__(self, args: argparse.Namespace) -> argparse.Namespace:
-        return self.func(args)
+45
+46
class ModuleConfig:
+    """
+    Represents a module's configuration, containing the following attributes:
+
+    Attributes:
+        func: A callable that executes the module's functionality.
+        add_args: A callable that populates the module's sub-parser arguments.
+        description: A description of the module's functionality.
+        help: A help message for the module's command-line interface.
+        common_parsers: A list of common parsers to add to the module's sub-parser, appending the 'core' and 'seed' parsers to those passed.
+    """
+
+    def __init__(
+        self,
+        func: Callable[..., argparse.Namespace],
+        add_args: Callable[..., None],
+        description: str,
+        help: str,
+        common_parsers: Optional[list[str]] = None,
+        no_seed: bool = False,
+    ) -> None:
+        self.func = func
+        self.add_args = add_args
+        self.description = description
+        self.help = help
+        self.common_parsers = ["core", "seed"] if not no_seed else ["core"]
+        if common_parsers:
+            assert set(common_parsers) <= COMMON_PARSERS.keys(), "Invalid common parser(s) specified."
+            # merge the below two assert statements
+            assert (
+                "core" not in common_parsers and "seed" not in common_parsers
+            ), "The 'seed' and 'core' parser groups are automatically added to all modules, remove the from `ModuleConfig`s."
+            self.common_parsers += common_parsers
+
+    def __call__(self, args: argparse.Namespace) -> argparse.Namespace:
+        return self.func(args)
 
@@ -2710,8 +2710,7 @@

Source code in src/nhssynth/cli/module_setup.py -
62
-63
+            
63
 64
 65
 66
@@ -2727,24 +2726,25 @@ 

76 77 78 -79

def add_config_args(parser: argparse.ArgumentParser) -> None:
-    """Adds arguments to `parser` relating to configuration file handling and module-specific config overrides."""
-    parser.add_argument(
-        "-c",
-        "--input-config",
-        required=True,
-        help="specify the config file name",
-    )
-    parser.add_argument(
-        "-cp",
-        "--custom-pipeline",
-        action="store_true",
-        help="infer a custom pipeline running order of modules from the config",
-    )
-    for module_name in PIPELINE:
-        MODULE_MAP[module_name].add_args(parser, f"{module_name} option overrides", overrides=True)
-    for module_name in VALID_MODULES - set(PIPELINE):
-        MODULE_MAP[module_name].add_args(parser, f"{module_name} options overrides", overrides=True)
+79
+80
def add_config_args(parser: argparse.ArgumentParser) -> None:
+    """Adds arguments to `parser` relating to configuration file handling and module-specific config overrides."""
+    parser.add_argument(
+        "-c",
+        "--input-config",
+        required=True,
+        help="specify the config file name",
+    )
+    parser.add_argument(
+        "-cp",
+        "--custom-pipeline",
+        action="store_true",
+        help="infer a custom pipeline running order of modules from the config",
+    )
+    for module_name in PIPELINE:
+        MODULE_MAP[module_name].add_args(parser, f"{module_name} option overrides", overrides=True)
+    for module_name in VALID_MODULES - set(PIPELINE):
+        MODULE_MAP[module_name].add_args(parser, f"{module_name} options overrides", overrides=True)
 
@@ -2768,13 +2768,13 @@

Source code in src/nhssynth/cli/module_setup.py -
56
-57
+            
def add_pipeline_args(parser: argparse.ArgumentParser) -> None:
-    """Adds arguments to `parser` for each module in the pipeline."""
-    for module_name in PIPELINE:
-        MODULE_MAP[module_name].add_args(parser, f"{module_name} options")
+59
+60
def add_pipeline_args(parser: argparse.ArgumentParser) -> None:
+    """Adds arguments to `parser` for each module in the pipeline."""
+    for module_name in PIPELINE:
+        MODULE_MAP[module_name].add_args(parser, f"{module_name} options")
 
@@ -2880,8 +2880,7 @@

Source code in src/nhssynth/cli/module_setup.py -
167
-168
+            
168
 169
 170
 171
@@ -2909,36 +2908,37 @@ 

193 194 195 -196

def add_subparser(
-    subparsers: argparse._SubParsersAction,
-    name: str,
-    module_config: ModuleConfig,
-) -> argparse.ArgumentParser:
-    """
-    Add a subparser to an argparse argument parser.
-
-    Args:
-        subparsers: The subparsers action to which the subparser will be added.
-        name: The name of the subparser.
-        module_config: A [`ModuleConfig`][nhssynth.cli.module_setup.ModuleConfig] object containing information about the subparser, including a function to execute and a function to add arguments.
-
-    Returns:
-        The newly created subparser.
-    """
-    parent_parsers = get_parent_parsers(name, module_config.common_parsers)
-    parser = subparsers.add_parser(
-        name=name,
-        description=module_config.description,
-        help=module_config.help,
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-        parents=parent_parsers,
-    )
-    if name not in {"pipeline", "config"}:
-        module_config.add_args(parser, f"{name} options")
-    else:
-        module_config.add_args(parser)
-    parser.set_defaults(func=module_config.func)
-    return parser
+196
+197
def add_subparser(
+    subparsers: argparse._SubParsersAction,
+    name: str,
+    module_config: ModuleConfig,
+) -> argparse.ArgumentParser:
+    """
+    Add a subparser to an argparse argument parser.
+
+    Args:
+        subparsers: The subparsers action to which the subparser will be added.
+        name: The name of the subparser.
+        module_config: A [`ModuleConfig`][nhssynth.cli.module_setup.ModuleConfig] object containing information about the subparser, including a function to execute and a function to add arguments.
+
+    Returns:
+        The newly created subparser.
+    """
+    parent_parsers = get_parent_parsers(name, module_config.common_parsers)
+    parser = subparsers.add_parser(
+        name=name,
+        description=module_config.description,
+        help=module_config.help,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        parents=parent_parsers,
+    )
+    if name not in {"pipeline", "config"}:
+        module_config.add_args(parser, f"{name} options")
+    else:
+        module_config.add_args(parser)
+    parser.set_defaults(func=module_config.func)
+    return parser
 
@@ -2962,21 +2962,21 @@

Source code in src/nhssynth/cli/module_setup.py -
157
-158
+            
def get_parent_parsers(name: str, module_parsers: list[str]) -> list[argparse.ArgumentParser]:
-    """Get a list of parent parsers for a given module, based on the module's `common_parsers` attribute."""
-    if name in {"pipeline", "config"}:
-        return [p(name == "config") for p in COMMON_PARSERS.values()]
-    elif name == "dashboard":
-        return [COMMON_PARSERS[pn](True) for pn in module_parsers]
-    else:
-        return [COMMON_PARSERS[pn]() for pn in module_parsers]
+164
+165
def get_parent_parsers(name: str, module_parsers: list[str]) -> list[argparse.ArgumentParser]:
+    """Get a list of parent parsers for a given module, based on the module's `common_parsers` attribute."""
+    if name in {"pipeline", "config"}:
+        return [p(name == "config") for p in COMMON_PARSERS.values()]
+    elif name == "dashboard":
+        return [COMMON_PARSERS[pn](True) for pn in module_parsers]
+    else:
+        return [COMMON_PARSERS[pn]() for pn in module_parsers]
 
@@ -3000,17 +3000,17 @@

Source code in src/nhssynth/cli/module_setup.py -
48
-49
+            
def run_pipeline(args: argparse.Namespace) -> None:
-    """Runs the specified pipeline of modules with the passed configuration `args`."""
-    print("Running full pipeline...")
-    args.modules_to_run = PIPELINE
-    for module_name in PIPELINE:
-        args = MODULE_MAP[module_name](args)
+53
+54
def run_pipeline(args: argparse.Namespace) -> None:
+    """Runs the specified pipeline of modules with the passed configuration `args`."""
+    print("Running full pipeline...")
+    args.modules_to_run = PIPELINE
+    for module_name in PIPELINE:
+        args = MODULE_MAP[module_name](args)
 
diff --git a/reference/common/common/index.html b/reference/common/common/index.html index e6a95c6c..42968cd6 100644 --- a/reference/common/common/index.html +++ b/reference/common/common/index.html @@ -2517,8 +2517,7 @@

Source code in src/nhssynth/common/common.py -
 9
-10
+            
10
 11
 12
 13
@@ -2527,17 +2526,18 @@ 

16 17 18 -19

def set_seed(seed: Optional[int] = None) -> None:
-    """
-    (Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.
-
-    Args:
-        seed: The seed to set.
-    """
-    if seed:
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        random.seed(seed)
+19
+20
def set_seed(seed: Optional[int] = None) -> None:
+    """
+    (Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.
+
+    Args:
+        seed: The seed to set.
+    """
+    if seed:
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        random.seed(seed)
 
diff --git a/reference/common/dicts/index.html b/reference/common/dicts/index.html index 5662fc95..2c1e122e 100644 --- a/reference/common/dicts/index.html +++ b/reference/common/dicts/index.html @@ -2597,8 +2597,7 @@

Source code in src/nhssynth/common/dicts.py -
 5
- 6
+            
 6
  7
  8
  9
@@ -2620,30 +2619,31 @@ 

25 26 27 -28

def filter_dict(d: dict, filter_keys: Union[set, list], include: bool = False) -> dict:
-    """
-    Given a dictionary, return a new dictionary either including or excluding keys in a given `filter` set.
-
-    Args:
-        d: A dictionary to filter.
-        filter_keys: A list or set of keys to either include or exclude.
-        include: Determine whether to return a dictionary including or excluding keys in `filter`.
-
-    Returns:
-        A filtered dictionary.
-
-    Examples:
-        >>> d = {'a': 1, 'b': 2, 'c': 3}
-        >>> filter_dict(d, {'a', 'b'})
-        {'c': 3}
-        >>> filter_dict(d, {'a', 'b'}, include=True)
-        {'a': 1, 'b': 2}
-    """
-    if include:
-        filtered_keys = set(filter_keys) & set(d.keys())
-    else:
-        filtered_keys = set(d.keys()) - set(filter_keys)
-    return {k: v for k, v in d.items() if k in filtered_keys}
+28
+29
def filter_dict(d: dict, filter_keys: Union[set, list], include: bool = False) -> dict:
+    """
+    Given a dictionary, return a new dictionary either including or excluding keys in a given `filter` set.
+
+    Args:
+        d: A dictionary to filter.
+        filter_keys: A list or set of keys to either include or exclude.
+        include: Determine whether to return a dictionary including or excluding keys in `filter`.
+
+    Returns:
+        A filtered dictionary.
+
+    Examples:
+        >>> d = {'a': 1, 'b': 2, 'c': 3}
+        >>> filter_dict(d, {'a', 'b'})
+        {'c': 3}
+        >>> filter_dict(d, {'a', 'b'}, include=True)
+        {'a': 1, 'b': 2}
+    """
+    if include:
+        filtered_keys = set(filter_keys) & set(d.keys())
+    else:
+        filtered_keys = set(d.keys()) - set(filter_keys)
+    return {k: v for k, v in d.items() if k in filtered_keys}
 
@@ -2753,8 +2753,7 @@

Source code in src/nhssynth/common/dicts.py -
56
-57
+            
57
 58
 59
 60
@@ -2779,33 +2778,34 @@ 

79 80 81 -82

def flatten_dict(d: dict[str, Any]) -> dict[str, Any]:
-    """
-    Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.
-
-    Args:
-        d: A dictionary with potentially nested keys.
-
-    Returns:
-        A flattened dictionary.
-
-    Raises:
-        ValueError: If duplicate keys are found in the flattened dictionary.
-
-    Examples:
-        >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
-        >>> flatten_dict(d)
-        {'a': 1, 'c': 2, 'e': 3}
-    """
-    items = []
-    for k, v in d.items():
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v).items())
-        else:
-            items.append((k, v))
-    if len(set([p[0] for p in items])) != len(items):
-        raise ValueError("Duplicate keys found in flattened dictionary")
-    return dict(items)
+82
+83
def flatten_dict(d: dict[str, Any]) -> dict[str, Any]:
+    """
+    Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.
+
+    Args:
+        d: A dictionary with potentially nested keys.
+
+    Returns:
+        A flattened dictionary.
+
+    Raises:
+        ValueError: If duplicate keys are found in the flattened dictionary.
+
+    Examples:
+        >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
+        >>> flatten_dict(d)
+        {'a': 1, 'c': 2, 'e': 3}
+    """
+    items = []
+    for k, v in d.items():
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v).items())
+        else:
+            items.append((k, v))
+    if len(set([p[0] for p in items])) != len(items):
+        raise ValueError("Duplicate keys found in flattened dictionary")
+    return dict(items)
 
@@ -2907,8 +2907,7 @@

Source code in src/nhssynth/common/dicts.py -
31
-32
+            
32
 33
 34
 35
@@ -2929,29 +2928,30 @@ 

50 51 52 -53

def get_key_by_value(d: dict, value: Any) -> Union[Any, None]:
-    """
-    Find the first key in a dictionary with a given value.
-
-    Args:
-        d: A dictionary to search through.
-        value: The value to search for.
-
-    Returns:
-        The first key in `d` with the value `value`, or `None` if no such key exists.
-
-    Examples:
-        >>> d = {'a': 1, 'b': 2, 'c': 1}
-        >>> get_key_by_value(d, 2)
-        'b'
-        >>> get_key_by_value(d, 3)
-        None
-
-    """
-    for key, val in d.items():
-        if val == value:
-            return key
-    return None
+53
+54
def get_key_by_value(d: dict, value: Any) -> Union[Any, None]:
+    """
+    Find the first key in a dictionary with a given value.
+
+    Args:
+        d: A dictionary to search through.
+        value: The value to search for.
+
+    Returns:
+        The first key in `d` with the value `value`, or `None` if no such key exists.
+
+    Examples:
+        >>> d = {'a': 1, 'b': 2, 'c': 1}
+        >>> get_key_by_value(d, 2)
+        'b'
+        >>> get_key_by_value(d, 3)
+        None
+
+    """
+    for key, val in d.items():
+        if val == value:
+            return key
+    return None
 
diff --git a/reference/common/io/index.html b/reference/common/io/index.html index 859eaef8..a3307984 100644 --- a/reference/common/io/index.html +++ b/reference/common/io/index.html @@ -2609,8 +2609,7 @@

Source code in src/nhssynth/common/io.py -
81
-82
+            
82
 83
 84
 85
@@ -2622,20 +2621,21 @@ 

91 92 93 -94

def check_exists(fns: list[str], dir: Path) -> None:
-    """
-    Checks if the files in `fns` exist in `dir`.
-
-    Args:
-        fns: The list of files to check.
-        dir: The directory the files should exist in.
-
-    Raises:
-        FileNotFoundError: If any of the files in `fns` do not exist in `dir`.
-    """
-    for fn in fns:
-        if not (dir / fn).exists():
-            raise FileNotFoundError(f"File {fn} does not exist at {dir}.")
+94
+95
def check_exists(fns: list[str], dir: Path) -> None:
+    """
+    Checks if the files in `fns` exist in `dir`.
+
+    Args:
+        fns: The list of files to check.
+        dir: The directory the files should exist in.
+
+    Raises:
+        FileNotFoundError: If any of the files in `fns` do not exist in `dir`.
+    """
+    for fn in fns:
+        if not (dir / fn).exists():
+            raise FileNotFoundError(f"File {fn} does not exist at {dir}.")
 
@@ -2741,8 +2741,7 @@

Source code in src/nhssynth/common/io.py -
23
-24
+            
24
 25
 26
 27
@@ -2754,20 +2753,21 @@ 

33 34 35 -36

def consistent_ending(fn: str, ending: str = ".pkl", suffix: str = "") -> str:
-    """
-    Ensures that the filename `fn` ends with `ending`. If not, removes any existing ending and appends `ending`.
-
-    Args:
-        fn: The filename to check.
-        ending: The desired ending to check for. Default is ".pkl".
-        suffix: A suffix to append to the filename before the ending.
-
-    Returns:
-        The filename with the correct ending and potentially an inserted suffix.
-    """
-    path_fn = Path(fn)
-    return str(path_fn.parent / path_fn.stem) + ("_" if suffix else "") + suffix + ending
+36
+37
def consistent_ending(fn: str, ending: str = ".pkl", suffix: str = "") -> str:
+    """
+    Ensures that the filename `fn` ends with `ending`. If not, removes any existing ending and appends `ending`.
+
+    Args:
+        fn: The filename to check.
+        ending: The desired ending to check for. Default is ".pkl".
+        suffix: A suffix to append to the filename before the ending.
+
+    Returns:
+        The filename with the correct ending and potentially an inserted suffix.
+    """
+    path_fn = Path(fn)
+    return str(path_fn.parent / path_fn.stem) + ("_" if suffix else "") + suffix + ending
 
@@ -2845,8 +2845,7 @@

Source code in src/nhssynth/common/io.py -
39
-40
+            
40
 41
 42
 43
@@ -2855,17 +2854,18 @@ 

46 47 48 -49

def consistent_endings(args: list[Union[str, tuple[str, str], tuple[str, str, str]]]) -> list[str]:
-    """
-    Wrapper around `consistent_ending` to apply it to a list of filenames.
-
-    Args:
-        args: The list of filenames to check. Can take the form of a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix.
-
-    Returns:
-        The list of filenames with the correct endings.
-    """
-    return list(consistent_ending(arg) if isinstance(arg, str) else consistent_ending(*arg) for arg in args)
+49
+50
def consistent_endings(args: list[Union[str, tuple[str, str], tuple[str, str, str]]]) -> list[str]:
+    """
+    Wrapper around `consistent_ending` to apply it to a list of filenames.
+
+    Args:
+        args: The list of filenames to check. Can take the form of a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix.
+
+    Returns:
+        The list of filenames with the correct endings.
+    """
+    return list(consistent_ending(arg) if isinstance(arg, str) else consistent_ending(*arg) for arg in args)
 
@@ -2957,8 +2957,7 @@

Source code in src/nhssynth/common/io.py -
 7
- 8
+            
 8
  9
 10
 11
@@ -2970,20 +2969,21 @@ 

17 18 19 -20

def experiment_io(experiment_name: str, dir_experiments: str = "experiments") -> str:
-    """
-    Create an experiment's directory and return the path.
-
-    Args:
-        experiment_name: The name of the experiment.
-        dir_experiments: The name of the directory containing all experiments.
-
-    Returns:
-        The path to the experiment directory.
-    """
-    dir_experiment = Path(dir_experiments) / experiment_name
-    dir_experiment.mkdir(parents=True, exist_ok=True)
-    return dir_experiment
+20
+21
def experiment_io(experiment_name: str, dir_experiments: str = "experiments") -> str:
+    """
+    Create an experiment's directory and return the path.
+
+    Args:
+        experiment_name: The name of the experiment.
+        dir_experiments: The name of the directory containing all experiments.
+
+    Returns:
+        The path to the experiment directory.
+    """
+    dir_experiment = Path(dir_experiments) / experiment_name
+    dir_experiment.mkdir(parents=True, exist_ok=True)
+    return dir_experiment
 
@@ -3075,8 +3075,7 @@

Source code in src/nhssynth/common/io.py -
52
-53
+            
53
 54
 55
 56
@@ -3090,22 +3089,23 @@ 

64 65 66 -67

def potential_suffix(fn: str, fn_base: str) -> str:
-    """
-    Checks if `fn` is a suffix (starts with an underscore) to append to `fn_base`, or a filename in its own right.
-
-    Args:
-        fn: The filename / potential suffix to append to `fn_base`.
-        fn_base: The name of the file the suffix would attach to.
-
-    Returns:
-        The appropriately processed `fn`
-    """
-    fn_base = Path(fn_base).stem
-    if fn[0] == "_":
-        return fn_base + fn
-    else:
-        return fn
+67
+68
def potential_suffix(fn: str, fn_base: str) -> str:
+    """
+    Checks if `fn` is a suffix (starts with an underscore) to append to `fn_base`, or a filename in its own right.
+
+    Args:
+        fn: The filename / potential suffix to append to `fn_base`.
+        fn_base: The name of the file the suffix would attach to.
+
+    Returns:
+        The appropriately processed `fn`.
+    """
+    fn_base = Path(fn_base).stem
+    if fn[0] == "_":
+        return fn_base + fn
+    else:
+        return fn
 
@@ -3173,23 +3173,23 @@

Source code in src/nhssynth/common/io.py -
70
-71
+            
def potential_suffixes(fns: list[str], fn_base: str) -> list[str]:
-    """
-    Wrapper around `potential_suffix` to apply it to a list of filenames.
-
-    Args:
-        fns: The list of filenames / potential suffixes to append to `fn_base`.
-        fn_base: The name of the file the suffixes would attach to.
-    """
-    return list(potential_suffix(fn, fn_base) for fn in fns)
+78
+79
def potential_suffixes(fns: list[str], fn_base: str) -> list[str]:
+    """
+    Wrapper around `potential_suffix` to apply it to a list of filenames.
+
+    Args:
+        fns: The list of filenames / potential suffixes to append to `fn_base`.
+        fn_base: The name of the file the suffixes would attach to.
+    """
+    return list(potential_suffix(fn, fn_base) for fn in fns)
 
@@ -3281,8 +3281,7 @@

Source code in src/nhssynth/common/io.py -
 97
- 98
+            
 98
  99
 100
 101
@@ -3297,23 +3296,24 @@ 

110 111 112 -113

def warn_if_path_supplied(fns: list[str], dir: Path) -> None:
-    """
-    Warns if the files in `fns` include directory separators.
-
-    Args:
-        fns: The list of files to check.
-        dir: The directory the files should exist in.
-
-    Warnings:
-        UserWarning: when the path to any of the files in `fns` includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.
-    """
-    for fn in fns:
-        if "/" in fn:
-            warnings.warn(
-                f"Using the path supplied appended to {dir}, i.e. attempting to read data from {dir / fn}",
-                UserWarning,
-            )
+113
+114
def warn_if_path_supplied(fns: list[str], dir: Path) -> None:
+    """
+    Warns if the files in `fns` include directory separators.
+
+    Args:
+        fns: The list of files to check.
+        dir: The directory the files should exist in.
+
+    Warnings:
+        UserWarning: when the path to any of the files in `fns` includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.
+    """
+    for fn in fns:
+        if "/" in fn:
+            warnings.warn(
+                f"Using the path supplied appended to {dir}, i.e. attempting to read data from {dir / fn}",
+                UserWarning,
+            )
 
diff --git a/reference/common/strings/index.html b/reference/common/strings/index.html index 9ccfcfa2..9eb9793d 100644 --- a/reference/common/strings/index.html +++ b/reference/common/strings/index.html @@ -2559,8 +2559,7 @@

Source code in src/nhssynth/common/strings.py -
 6
- 7
+            
 7
  8
  9
 10
@@ -2575,23 +2574,24 @@ 

19 20 21 -22

def add_spaces_before_caps(string: str) -> str:
-    """
-    Adds spaces before capital letters in a string if there is a lower-case letter following it.
-
-    Args:
-        string: The string to add spaces to.
-
-    Returns:
-        The string with spaces added before capital letters.
-
-    Examples:
-        >>> add_spaces_before_caps("HelloWorld")
-        'Hello World'
-        >>> add_spaces_before_caps("HelloWorldAGAIN")
-        'Hello World AGAIN'
-    """
-    return " ".join(re.findall(r"[a-z]?[A-Z][a-z]+|[A-Z]+(?=[A-Z][a-z]|\b)", string))
+22
+23
def add_spaces_before_caps(string: str) -> str:
+    """
+    Adds spaces before capital letters in a string if there is a lower-case letter following it.
+
+    Args:
+        string: The string to add spaces to.
+
+    Returns:
+        The string with spaces added before capital letters.
+
+    Examples:
+        >>> add_spaces_before_caps("HelloWorld")
+        'Hello World'
+        >>> add_spaces_before_caps("HelloWorldAGAIN")
+        'Hello World AGAIN'
+    """
+    return " ".join(re.findall(r"[a-z]?[A-Z][a-z]+|[A-Z]+(?=[A-Z][a-z]|\b)", string))
 
@@ -2683,8 +2683,7 @@

Source code in src/nhssynth/common/strings.py -
25
-26
+            
26
 27
 28
 29
@@ -2706,30 +2705,31 @@ 

45 46 47 -48

def format_timedelta(start: float, finish: float) -> str:
-    """
-    Calculate and prettily format the difference between two calls to `time.time()`.
-
-    Args:
-        start: The start time.
-        finish: The finish time.
-
-    Returns:
-        A string containing the time difference in a human-readable format.
-    """
-    total = datetime.timedelta(seconds=finish - start)
-    hours, remainder = divmod(total.seconds, 3600)
-    minutes, seconds = divmod(remainder, 60)
-
-    if total.days > 0:
-        delta_str = f"{total.days}d {hours}h {minutes}m {seconds}s"
-    elif hours > 0:
-        delta_str = f"{hours}h {minutes}m {seconds}s"
-    elif minutes > 0:
-        delta_str = f"{minutes}m {seconds}s"
-    else:
-        delta_str = f"{seconds}s"
-    return delta_str
+48
+49
def format_timedelta(start: float, finish: float) -> str:
+    """
+    Calculate and prettily format the difference between two calls to `time.time()`.
+
+    Args:
+        start: The start time.
+        finish: The finish time.
+
+    Returns:
+        A string containing the time difference in a human-readable format.
+    """
+    total = datetime.timedelta(seconds=finish - start)
+    hours, remainder = divmod(total.seconds, 3600)
+    minutes, seconds = divmod(remainder, 60)
+
+    if total.days > 0:
+        delta_str = f"{total.days}d {hours}h {minutes}m {seconds}s"
+    elif hours > 0:
+        delta_str = f"{hours}h {minutes}m {seconds}s"
+    elif minutes > 0:
+        delta_str = f"{minutes}m {seconds}s"
+    else:
+        delta_str = f"{seconds}s"
+    return delta_str
 
diff --git a/reference/modules/dataloader/metadata/index.html b/reference/modules/dataloader/metadata/index.html index 6f26b598..539d63df 100644 --- a/reference/modules/dataloader/metadata/index.html +++ b/reference/modules/dataloader/metadata/index.html @@ -3083,79 +3083,79 @@

assembled_metadata = { "columns": { cn: { - "dtype": cmd.dtype.name - if not hasattr(cmd, "datetime_config") - else {"name": cmd.dtype.name, **cmd.datetime_config}, - "categorical": cmd.categorical, - } - for cn, cmd in self._metadata.items() - } - } - # We loop through the base dict above to add other parts if they are present in the metadata - for cn, cmd in self._metadata.items(): - if cmd.missingness_strategy: - assembled_metadata["columns"][cn]["missingness"] = ( - cmd.missingness_strategy.name - if cmd.missingness_strategy.name != "impute" - else {"name": cmd.missingness_strategy.name, "impute": cmd.missingness_strategy.impute} - ) - if cmd.transformer_config: - assembled_metadata["columns"][cn]["transformer"] = { - **cmd.transformer_config, - "name": cmd.transformer.__class__.__name__, - } - - # Add back the dropped_columns not present in the metadata - if self.dropped_columns: - assembled_metadata["columns"].update({cn: "drop" for cn in self.dropped_columns}) - - if collapse_yaml: - assembled_metadata = self._collapse(assembled_metadata) - - # We add the constraints section after all of the formatting and processing above - # In general, the constraints are kept the same as the input (provided they passed validation) - # If `collapse_yaml` is specified, we output the minimum set of equivalent constraints - if self.constraints: - assembled_metadata["constraints"] = ( - [str(c) for c in self.constraints.minimal_constraints] - if collapse_yaml - else self.constraints.raw_constraint_strings - ) - return assembled_metadata - - def save(self, path: pathlib.Path, collapse_yaml: bool) -> None: - """ - Writes metadata to a YAML file. - - Args: - path: The path at which to write the metadata YAML file. - collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication. - """ - with open(path, "w") as yaml_file: - yaml.safe_dump( - self._assemble(collapse_yaml), - yaml_file, - default_flow_style=False, - sort_keys=False, - ) - - def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]: - """ - Map combinations of our metadata implementation to SDV's as required by SDMetrics. - - Returns: - A dictionary containing the SDV metadata. 
- """ - sdv_metadata = { - "columns": { - cn: { - "sdtype": "boolean" - if cmd.boolean - else "categorical" - if cmd.categorical - else "datetime" - if cmd.dtype.kind == "M" - else "numerical", + "dtype": ( + cmd.dtype.name + if not hasattr(cmd, "datetime_config") + else {"name": cmd.dtype.name, **cmd.datetime_config} + ), + "categorical": cmd.categorical, + } + for cn, cmd in self._metadata.items() + } + } + # We loop through the base dict above to add other parts if they are present in the metadata + for cn, cmd in self._metadata.items(): + if cmd.missingness_strategy: + assembled_metadata["columns"][cn]["missingness"] = ( + cmd.missingness_strategy.name + if cmd.missingness_strategy.name != "impute" + else {"name": cmd.missingness_strategy.name, "impute": cmd.missingness_strategy.impute} + ) + if cmd.transformer_config: + assembled_metadata["columns"][cn]["transformer"] = { + **cmd.transformer_config, + "name": cmd.transformer.__class__.__name__, + } + + # Add back the dropped_columns not present in the metadata + if self.dropped_columns: + assembled_metadata["columns"].update({cn: "drop" for cn in self.dropped_columns}) + + if collapse_yaml: + assembled_metadata = self._collapse(assembled_metadata) + + # We add the constraints section after all of the formatting and processing above + # In general, the constraints are kept the same as the input (provided they passed validation) + # If `collapse_yaml` is specified, we output the minimum set of equivalent constraints + if self.constraints: + assembled_metadata["constraints"] = ( + [str(c) for c in self.constraints.minimal_constraints] + if collapse_yaml + else self.constraints.raw_constraint_strings + ) + return assembled_metadata + + def save(self, path: pathlib.Path, collapse_yaml: bool) -> None: + """ + Writes metadata to a YAML file. + + Args: + path: The path at which to write the metadata YAML file. + collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication. + """ + with open(path, "w") as yaml_file: + yaml.safe_dump( + self._assemble(collapse_yaml), + yaml_file, + default_flow_style=False, + sort_keys=False, + ) + + def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]: + """ + Map combinations of our metadata implementation to SDV's as required by SDMetrics. + + Returns: + A dictionary containing the SDV metadata. + """ + sdv_metadata = { + "columns": { + cn: { + "sdtype": ( + "boolean" + if cmd.boolean + else "categorical" if cmd.categorical else "datetime" if cmd.dtype.kind == "M" else "numerical" + ), } for cn, cmd in self._metadata.items() } @@ -3669,9 +3669,7 @@

Source code in src/nhssynth/modules/dataloader/metadata.py -
309
-310
-311
+            
311
 312
 313
 314
@@ -3693,23 +3691,21 @@ 

330 331 332 -333

def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:
-    """
-    Map combinations of our metadata implementation to SDV's as required by SDMetrics.
-
-    Returns:
-        A dictionary containing the SDV metadata.
-    """
-    sdv_metadata = {
-        "columns": {
-            cn: {
-                "sdtype": "boolean"
-                if cmd.boolean
-                else "categorical"
-                if cmd.categorical
-                else "datetime"
-                if cmd.dtype.kind == "M"
-                else "numerical",
+333
def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:
+    """
+    Map combinations of our metadata implementation to SDV's as required by SDMetrics.
+
+    Returns:
+        A dictionary containing the SDV metadata.
+    """
+    sdv_metadata = {
+        "columns": {
+            cn: {
+                "sdtype": (
+                    "boolean"
+                    if cmd.boolean
+                    else "categorical" if cmd.categorical else "datetime" if cmd.dtype.kind == "M" else "numerical"
+                ),
             }
             for cn, cmd in self._metadata.items()
         }
@@ -3785,9 +3781,7 @@ 

Source code in src/nhssynth/modules/dataloader/metadata.py -
293
-294
-295
+            
295
 296
 297
 298
@@ -3799,21 +3793,23 @@ 

304 305 306 -307

def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:
-    """
-    Writes metadata to a YAML file.
-
-    Args:
-        path: The path at which to write the metadata YAML file.
-        collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.
-    """
-    with open(path, "w") as yaml_file:
-        yaml.safe_dump(
-            self._assemble(collapse_yaml),
-            yaml_file,
-            default_flow_style=False,
-            sort_keys=False,
-        )
+307
+308
+309
def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:
+    """
+    Writes metadata to a YAML file.
+
+    Args:
+        path: The path at which to write the metadata YAML file.
+        collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.
+    """
+    with open(path, "w") as yaml_file:
+        yaml.safe_dump(
+            self._assemble(collapse_yaml),
+            yaml_file,
+            default_flow_style=False,
+            sort_keys=False,
+        )
 
diff --git a/search/search_index.json b/search/search_index.json index 5a04c0d3..db8c3be1 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"NHS Synth","text":"

This is a package for generating useful synthetic data, audited and assessed along the dimensions of utility, privacy and fairness. Currently, the main focus of the package in its beta stage is to experiment with different model architectures to find which are the most promising for real-world usage.

See the User Guide to get started with running an experiment with the package.

See the Development Guide and Code Reference to get started with contributing to the package.

"},{"location":"development_guide/","title":"Development guide","text":"

This document aims to provide a comprehensive set of instructions for continuing development of this package. Good knowledge of Python development is assumed. Some ways of working are subjective and preferential; as such, we aim to proscribe other ways of working as little as possible.

"},{"location":"development_guide/#development-environment-setup","title":"Development environment setup","text":""},{"location":"development_guide/#python","title":"Python","text":"

The package currently supports versions 3.9, 3.10 and 3.11 of Python. We recommend installing all of these versions; at minimum the latest supported version of Python should be used. Many people use pyenv for managing multiple Python versions. On macOS, Homebrew is a good, less invasive option for this (provided you then use a virtual environment manager too). For virtual environment management, we recommend Python's in-built venv functionality, but conda or some similar system would suffice (note that in the section below it may not be necessary to use any specific virtual environment management at all depending on the setup of Poetry).

"},{"location":"development_guide/#poetry","title":"Poetry","text":"

We use Poetry to manage dependencies and the actual packaging and publishing of NHSSynth to PyPI. Poetry is a more robust alternative to a requirements.txt file, allowing for grouped dependencies and advanced build options. Rather than freezing a specific pip state, Poetry only specifies the top-level dependencies and then handles the resolution and installation of the latest compatible versions of the full dependency tree per these top-level dependencies. See the pyproject.toml in the GitHub repository and Poetry's documentation for further context.

Once Poetry is installed (in your preferred way per the instructions on their website), you can choose one of two options:

  1. Allow poetry to control virtual environments in its proprietary way, such that when you install and develop the package, poetry will automatically create a virtual environment for you.

  2. Change poetry's configuration to manage your own virtual environments:

    poetry config virtualenvs.create false\npoetry config virtualenvs.in-project false\n

    In this setup, a virtual environment can be instantiated and activated in whichever way you prefer. For example, using venv:

    python3.11 -m venv nhssynth-3.11\nsource nhssynth-3.11/bin/activate\n
"},{"location":"development_guide/#package-installation","title":"Package installation","text":"

At this point, the project dependencies can be installed via poetry install --with dev (add optional flags: --with aux to work with the auxiliary notebooks, --with docs to work with the documentation). This will install the package in editable mode, meaning that changes to the source code will be reflected in the installed package without needing to reinstall it. Note that if you are using your own virtual environment, you will need to activate it before running this command.

You can then interact with the package in one of two ways:

  1. Via the CLI module, which is accessed using the nhssynth command, e.g.

    poetry run nhssynth ...\n

    Note that you can omit the poetry run part and just type nhssynth if you followed the optional steps above to manage and activate your own virtual environment, or if you have executed poetry shell beforehand.

  2. Through directly importing parts of the package to use in an existing project (from nhssynth.modules... import ...).

"},{"location":"development_guide/#secure-mode","title":"Secure mode","text":"

Note that in order to train a generator in secure mode (see the documentation for details), the PyTorch extension package csprng must be installed separately. Currently this package's dependencies are not compatible with recent versions of PyTorch (the authors plan to rectify this - watch this space), so you will need to install it manually; you can do this in your environment by running:

git clone git@github.com:pytorch/csprng.git\ncd csprng\ngit branch release \"v0.2.2-rc1\"\ngit checkout release\npython setup.py install\n
"},{"location":"development_guide/#coding-practices","title":"Coding practices","text":""},{"location":"development_guide/#style","title":"Style","text":"

We use black for code formatting. This is a fairly opinionated formatter, but it is widely used and has a good reputation. We also use ruff to manage imports and lint the code. Both of these tools are run automatically via pre-commit hooks. Ensure you have installed the package with the dev group of dependencies and then run the following command to install the hooks:

pre-commit install\n

Note that you may need to prepend this command with poetry run if you are not using your own virtual environment.

This will ensure that your code conforms to the two formatters' / linters' requirements each time you commit to a branch. black and ruff are also run as part of the CI workflow discussed below, such that even without these hooks, the code will be checked and raise an error on GitHub if it is not formatted consistently.

Configuration for both packages can be found in the pyproject.toml; it should be picked up automatically by the pre-commit hooks, your IDE, and manual runs from the command line. The main configuration is as follows:

[tool.black]\nline-length = 120\n\n[tool.ruff]\ninclude = [\"*.py\", \"*.pyi\", \"**/pyproject.toml\", \"*.ipynb\"]\nselect = [\"E4\", \"E7\", \"E9\", \"F\", \"C90\", \"I\"]\n\n[tool.ruff.per-file-ignores]\n\"src/nhssynth/common/constants.py\" = [\"F403\", \"F405\"]\n\n[tool.ruff.isort]\nknown-first-party = [\"nhssynth\"]\n

This ensures that absolute imports from NHSSynth are sorted separately from the rest of the imports in a file.

There are a number of other hooks used as part of this repository's pre-commit configuration, including one that automatically mirrors the Poetry versions of these packages in the dev group per the list of supported packages and .poetry-sync-db.json. Roughly, these other hooks ensure correct formatting of .yaml and .toml files, check for large files being added to a commit, strip notebook output from the files, and fix whitespace and end-of-file issues. These are mostly consistent with the NHSx analytics project template's hooks.

"},{"location":"development_guide/#documentation","title":"Documentation","text":"

There should be Google-style docstrings on all non-trivial functions and classes. Ideally a docstring should take the form:

def func(arg1: type1, arg2: type2) -> returntype:\n    \"\"\"\n    One-line summary of the function.\n    AND / OR\n    Longer description of the function, including any caveats or assumptions where appropriate.\n\n    Args:\n        arg1: Description of arg1.\n        arg2: Description of arg2.\n\n    Returns:\n        Description of the return value.\n    \"\"\"\n    ...\n

These docstrings are then compiled into a full API documentation tree as part of a larger MkDocs documentation site hosted via GitHub (the one you are reading right now!). This process is derived from this tutorial.

The MkDocs page is built using the mkdocs-material theme. The documentation is built and hosted automatically via GitHub Pages.

The other parts of this site comprise markdown documents in the docs folder. Adding new pages is handled in the mkdocs.yml file as in any other Material MkDocs site. See their documentation if more complex changes to the site are required.

"},{"location":"development_guide/#testing","title":"Testing","text":"

We use tox to manage the execution of tests for the package against multiple versions of Python, and to ensure that they are being run in a clean environment. To run the tests, simply execute tox in the root directory of the repository. This will run the tests against all supported versions of Python. To run the tests against a specific version of Python, use tox -e py311 (or py310 or py39).

"},{"location":"development_guide/#configuration","title":"Configuration","text":"

See the tox.ini file for more information on the testing configuration. We follow the Poetry documentation on tox support to ensure that for each version of Python, tox will create an sdist package of the project and use pip to install it in a fresh environment. Thus, dependencies are resolved by pip in the first place and then updated to the locked dependencies in poetry.lock by running poetry install ... in this fresh environment. The tests are then run with pytest, which is configured in the pyproject.toml file. This configuration is fairly minimal: simply specifying the testing directory as the tests folder and filtering some known warnings.

[tool.pytest.ini_options]\ntestpaths = \"tests\"\nfilterwarnings = [\"ignore::DeprecationWarning:pkg_resources\"]\n

We can also use coverage to check the test coverage of the package. This is configured in the pyproject.toml file as follows:

[tool.coverage.run]\nsource = [\"src/nhssynth/cli\", \"src/nhssynth/common\", \"src/nhssynth/modules\"]\nomit = [\n    \"src/nhssynth/common/debugging.py\",\n]\n

We omit debugging.py as it is a wrapper for reading full trace-backs of warnings and not to be imported directly.

"},{"location":"development_guide/#adding-tests","title":"Adding Tests","text":"

We use the pytest framework for testing. The testing directory structure mirrors that of src. The usual testing practices apply.
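For instance, a test module for the dictionary helpers might look like the following minimal sketch (the file would live somewhere like tests/common/test_dicts.py to mirror src; the expected values are taken from the docstring examples in the code reference, and the exact file name is illustrative):

from nhssynth.common.dicts import filter_dict, get_key_by_value


def test_filter_dict_excludes_by_default():
    d = {"a": 1, "b": 2, "c": 3}
    assert filter_dict(d, {"a", "b"}) == {"c": 3}


def test_filter_dict_include():
    d = {"a": 1, "b": 2, "c": 3}
    assert filter_dict(d, {"a", "b"}, include=True) == {"a": 1, "b": 2}


def test_get_key_by_value():
    d = {"a": 1, "b": 2, "c": 1}
    assert get_key_by_value(d, 2) == "b"
    assert get_key_by_value(d, 3) is None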

"},{"location":"development_guide/#releases","title":"Releases","text":""},{"location":"development_guide/#version-management","title":"Version management","text":"

The package's version should be updated following the semantic versioning framework. The package is currently in a pre-release state, such that major version 1.0.0 should only be tagged once the package is functionally complete and stable.

To update the package's metadata, we can use Poetry's version command:

poetry version <version>\n

We can then commit and push the changes to the version file, and create a new tag:

git add pyproject.toml\ngit commit -m \"Bump version to <version>\"\ngit push\n

We should then tag the release using GitHub's CLI (or manually via git if you prefer):

gh release create <version> --generate-notes\n

This will create a new release on GitHub, and will automatically generate a changelog based on the commit messages and PRs closed since the last release. This changelog can then be edited to add more detail if necessary.

"},{"location":"development_guide/#building-and-publishing-to-pypi","title":"Building and publishing to PyPI","text":"

Poetry offers not only dependency management, but also a simple way to build and distribute the package.

After tagging a release per the section above, we can build the package using Poetry's build command:

poetry build\n

This will create a dist folder containing the built package. To publish this to PyPI, we can use the publish command:

poetry publish\n

This will prompt for PyPI credentials, and then publish the package. Note that this will only work if you have been added as a Maintainer of the package on PyPI.

It might be preferable at some point in the future to set up Trusted Publisher Management via OpenID Connect (OIDC) to allow for automated publishing of the package via a GitHub workflow. See the \"Publishing\" tab of NHSSynth's project management panel on PyPI to set this up.

"},{"location":"development_guide/#github","title":"GitHub","text":""},{"location":"development_guide/#continuous-integration","title":"Continuous integration","text":"

We use GitHub Actions for continuous integration. The different workflows comprising this can be found in the .github/workflows folder. In general, the CI workflow is triggered on every push to the main or a feature branch - as appropriate - and runs tests against all supported versions of Python. It also runs black and ruff to check that the code is formatted correctly, and builds the documentation site.

There are also scripts to update the dynamic badges in the README. These work via a gist associated with the repository. It is not easy to transfer ownership of this process, so if they break please feel free to contact me.

"},{"location":"development_guide/#branching","title":"Branching","text":"

We encourage the use of the Gitflow branching model for development. This means that the main branch is always in a stable state, and that all development work is done on feature branches. These feature branches are then merged into main via pull requests. The main branch is protected, such that pull requests must be reviewed and approved before they can be merged.

At minimum, the main branch's protection should be maintained, and roughly one branch per issue should be used. Ensure that all of the CI checks pass before merging.

"},{"location":"development_guide/#security-and-vulnerability-management","title":"Security and vulnerability management","text":"

The GitHub repository for the package has Dependabot, code scanning, and other security features enabled. These should be monitored continuously and any issues resolved as soon as possible. When issues of this type require a specific version of a dependency to be specified (and it is one that is not already amongst the dependency groups of the package), the version should be referenced as part of the security group of dependencies (i.e. with poetry add <package> --group security) and a new release created (see above).

"},{"location":"downstream_tasks/","title":"Defining a downstream task","text":"

A synthetic dataset is likely to be associated with specific modelling efforts or metrics that are not included in the general suite of evaluation tools supported more explicitly by this package. Additionally, analyses of model outputs for bias and fairness provided via Aequitas require some basis of predictions on which to perform the analysis. For these reasons, we provide a simple interface for defining a custom downstream task.

All downstream tasks are to be located in a folder named tasks in the working directory of the project, with subfolders for each dataset, i.e. the tasks associated with the support dataset should be located in the tasks/support directory.

The interface is then quite simple:

  • There should be a function called run that takes a single argument: dataset (additional arguments could be provided with some further configuration if there is a need for this)
  • The run function should fit a model and / or calculate some metric(s) on the dataset.
  • It should then return predicted probabilities for the outcome variable(s) in the dataset and a dictionary of metrics.
  • The file should contain a top-level variable containing an instantiation of the nhssynth Task class.

See the example below of a logistic regression model fit on the support dataset with the event variable as the outcome and rocauc as the metric of interest:

import pandas as pd\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import roc_auc_score\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.preprocessing import StandardScaler\n\nfrom nhssynth.modules.evaluation.tasks import Task\n\n\ndef run(dataset: pd.DataFrame) -> tuple[pd.DataFrame, dict]:\n    # Split the dataset into features and target\n    target = \"event\"\n\n    data = dataset.dropna()\n    X, y = data.drop([\"dob\", \"x3\", target], axis=1), data[target]\n    X_train, X_test, y_train, y_test = train_test_split(\n        StandardScaler().fit_transform(X), y, test_size=0.33, random_state=42\n    )\n\n    lr = LogisticRegression()\n    lr.fit(X_train, y_train)\n\n    # Get the predicted probabilities and predictions\n    probs = pd.DataFrame(lr.predict_proba(X_test)[:, 1], columns=[f\"lr_{target}_prob\"])\n\n    rocauc = roc_auc_score(y_test, probs)\n\n    return probs, {\"rocauc_lr\": rocauc}\n\n\ntask = Task(\"Logistic Regression on 'event'\", run, supports_aequitas=True)\n

Note the highlighted lines above:

  1. The Task class has been imported from nhssynth.modules.evaluation.tasks
  2. The run function should accept one argument and return a tuple
  3. The second element of this tuple should be a dictionary labelling each metric of interest (this name will be used in the dashboard as identification so ensure it is unique to the experiment)
  4. The task should be instantiated with a name, the run function and a boolean indicating whether the task supports Aequitas analysis; if the task does not support Aequitas analysis, then the first element of the tuple will not be used and None can be returned instead.

The rest of this file can contain any arbitrary code that runs within these constraints; this could be a simple model as above, or a more complex pipeline of transformations and models to match a pre-existing workflow.

"},{"location":"getting_started/","title":"Getting Started","text":""},{"location":"getting_started/#running-an-experiment","title":"Running an experiment","text":"

This package offers two easy ways to run reproducible and highly-configurable experiments. The following sections describe how to use each of these two methods.

"},{"location":"getting_started/#via-the-cli","title":"Via the CLI","text":"

The CLI is the easiest way to quickly run an experiment. It is designed to be as simple as possible, whilst still offering a high degree of configurability. An example command to run a full pipeline experiment is:

nhssynth pipeline \\\n    --experiment-name test \\\n    --dataset support \\\n    --seed 123 \\\n    --architecture DPVAE PATEGAN DECAF \\\n    --repeats 3 \\\n    --downstream-tasks \\\n    --column-similarity-metrics CorrelationSimilarity ContingencySimilarity \\\n    --column-shape-metrics KSComplement TVComplement \\\n    --boundary-metrics BoundaryAdherence \\\n    --synthesis-metrics NewRowSynthesis \\\n    --divergence-metrics ContinuousKLDivergence DiscreteKLDivergence\n

This will run a full pipeline experiment on the support dataset in the data directory. The outputs of the experiment will be recorded in a folder named test (corresponding to the experiment name) in the experiments directory.

In total, three different model architectures will be trained three times each with their default configurations. The resulting generated synthetic datasets will be evaluated via the downstream tasks in tasks/support alongside the metrics specified in the command. A dashboard will then be built automatically to exhibit the results.

The components of the run are persisted to the experiment's folder. Suppose you have already run this experiment and want to add some new evaluations. You do not have to re-run the entire experiment; you can simply run:

nhssynth evaluation -e test -d support -s 123 --coverage-metrics RangeCoverage CategoryCoverage\nnhssynth dashboard -e test -d support\n

This will regenerate the dashboard with a different set of metrics corresponding to the arguments passed to evaluation. Note that the --experiment-name and --dataset arguments are required for all commands, as they are used to identify the experiment and ensure reproducibility.

"},{"location":"getting_started/#via-a-configuration-file","title":"Via a configuration file","text":"

A yaml configuration file placed in the config folder can be used to get the same result as the above:

seed: 123\nexperiment_name: test\nrun_type: pipeline\nmodel:\n  architecture:\n    - DPVAE\n    - DPGAN\n    - DECAF\n  max_grad_norm: 5.0\n  secure_mode: false\n  repeats: 4\nevaluation:\n  downstream_tasks: true\n  column_shape_metrics:\n  - KSComplement\n  - TVComplement\n  column_similarity_metrics:\n  - CorrelationSimilarity\n  - ContingencySimilarity\n  boundary_metrics:\n  - BoundaryAdherence\n  synthesis_metrics:\n  - NewRowSynthesis\n  divergence_metrics:\n  - ContinuousKLDivergence\n  - DiscreteKLDivergence\n

Once saved as run_pipeline.yaml in the config directory, the package can be run under the configuration laid out in the file via:

nhssynth config -c run_pipeline\n

Note that if you run via the CLI, you can add the --save-config flag to your command to save the configuration file in the experiments/test (or whatever the --experiment-name has been set to) directory. This allows for easy reproduction of an experiment at a later date or on someone else's computer through sharing the configuration file with them.

"},{"location":"getting_started/#setting-up-a-datasets-metadata","title":"Setting up a dataset's metadata","text":"

For each dataset you wish to work with, it is advisable to set up a corresponding metadata file. The package will infer this when information is missing (and you can then tweak it). The reason we suggest specifying metadata in this way is that pandas / Python are in general bad at interpreting CSV files, particularly the specifics of data types, date objects and so on.

To do this, we must create a metadata YAML file in the dataset's directory. For example, for the support dataset, this file is located at data/support_metadata.yaml. By default, the package will look for a file with the same name as the dataset in the dataset's directory, but with _metadata appended to the end. Like most other file-naming conventions, this is configurable via the CLI.

The metadata file is split into two sections: columns and constraints. The former specifies the nature of each column in the dataset, whilst the latter specifies any constraints that should be enforced on the dataset.

"},{"location":"getting_started/#column-metadata","title":"Column metadata","text":"

Again, we refer to the support dataset's metadata file as an example:

columns:\n  dob:\n    dtype:\n      name: datetime64\n      floor: S\n  x1:\n    categorical: true\n    dtype: int64\n  x2:\n    categorical: true\n    dtype: int64\n  x3:\n    categorical: true\n  x4:\n    categorical: true\n    dtype: int64\n  x5:\n    categorical: true\n    dtype: int64\n  x6:\n    categorical: true\n    dtype: int64\n  x7:\n    dtype: int64\n  x8:\n    dtype: float64\n    missingness:\n      impute: mean\n  x9:\n    dtype: int64\n  x10:\n    dtype:\n      name: float64\n      rounding_scheme: 0.1\n  x11:\n    dtype: int64\n  x12:\n    dtype: float64\n  x13:\n    dtype: float64\n  x14:\n    dtype: float64\n  duration:\n    dtype: int64\n  event:\n    categorical: true\n    dtype: int64\n

For each column in the dataset, we specify the following:

  • Its dtype; this can be any numpy data type or a datetime type.
  • In the case of a datetime type, we also specify the floor (i.e. the smallest unit of time that we care about). In general this should be set to match the smallest unit of time in the dataset.
  • In the case of a float type, we can also specify a rounding_scheme to round the values to a certain number of decimal places; this should be set to match the rounding applied to the column in the real data, or used if you want to round the values for some other reason.
  • Whether it is categorical or not. If a column is not categorical, you don't need to specify this. A column is inferred as categorical if it has fewer than 10 unique values or is a string type.
  • If the column has missing values, we can specify how to deal with them via a missingness strategy. In the case of the x8 column, we impute the missing values with the column's mean. If you don't specify this, the global missingness strategy specified via the CLI or configuration file will be applied instead (this defaults to the augment strategy, which models the missingness as a separate level in the case of categorical features, or as a separate cluster in the case of continuous features).
"},{"location":"getting_started/#constraints","title":"Constraints","text":"

The second part of the metadata file specifies any constraints that should be enforced on the dataset. A constraint can be relative (between two columns) or fixed (comparing a single column against a constant). For example, the support dataset's constraints are as follows (note that these are arbitrarily defined and do not necessarily reflect the real data):

constraints:\n  - \"x10 in (0,100)\"\n  - \"x12 in (0,100)\"\n  - \"x13 in (0,100)\"\n  - \"x10 <= x12\"\n  - \"x12 < x13\"\n  - \"x10 < x13\"\n  - \"x8 > x10\"\n  - \"x8 > x12\"\n  - \"x8 > x13\"\n  - \"x11 > 100\"\n  - \"x12 > 10\"\n

The function of these constraints is fairly self-explanatory: the package ensures the constraints are feasible and minimises them before applying transformations to ensure that they will also be satisfied in the synthetic data. When a column does not meet a feasible constraint in the real data, we assume that this is intentional and use the violation as a feature upon which to generate synthetic data that also violates the constraint.

There is a further constraint, fixcombo, that only applies to categorical columns. This specifies that only existing combinations of two or more categorical columns should be generated, i.e. the columns can be collapsed into a single composite feature. For example, if we have a column for pregnancy and another for sex, we may only want to allow the three categories 'male:not-pregnant', 'female:pregnant' and 'female:not-pregnant'. This is specified as follows:

constraints:\n  - \"pregnancy fixcombo sex\"\n

In summary, we support the following constraint types:

  • fixcombo for categorical columns
  • > and < for non-categorical columns
  • >= and <= for non-categorical columns
  • in for non-categorical columns, which is effectively two of the above constraints combined, i.e. x in [a, b) is equivalent to x >= a and x < b. This is purely a UX feature and is treated as two separate constraints internally.

Once this metadata is set up, you are ready to run your experiment.

"},{"location":"getting_started/#evaluation","title":"Evaluation","text":"

Once models have been trained and synthetic datasets generated, we leverage evaluations from SDMetrics, Aequitas and the NHS's internal SynAdvSuite (at the time of writing you must request access to this repository to use the privacy-related attacks it implements), and also offer a facility for the custom specification of downstream tasks. These evaluations are then aggregated into a dashboard for ease of comparison and analysis.

See the relevant documentation for each of these packages for more information on the metrics they offer.

"},{"location":"model_card/","title":"Model Card: Variational AutoEncoder with Differential Privacy","text":""},{"location":"model_card/#model-details","title":"Model Details","text":"

The implementation of the Variational AutoEncoder (VAE) with Differential Privacy within this repository is based on work done by Dominic Danks during an NHSX Analytics Unit PhD internship (last commit to the original SynthVAE repository: commit 88a4bdf). This model card describes an updated and extended version of the model, by Harrison Wilde. Further information about the previous version created by Dom and its model implementation can be found in Section 5.4 of the associated report.

"},{"location":"model_card/#model-use","title":"Model Use","text":""},{"location":"model_card/#intended-use","title":"Intended Use","text":"

This model is intended for use in experimenting with differential privacy and VAEs.

"},{"location":"model_card/#training-data","title":"Training Data","text":"

Experiments in this repository are run against the Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT) dataset, accessed via the pycox Python library. We also performed further analysis on a single table that we extracted from MIMIC-III.

"},{"location":"model_card/#performance-and-limitations","title":"Performance and Limitations","text":"

A from-scratch VAE implementation was compared against various models available within the SDV framework using a variety of quality and privacy metrics on the SUPPORT dataset. The VAE was found to be competitive with all of these models across the various metrics. Differential Privacy (DP) was introduced via DP-SGD and the performance of the VAE for different levels of privacy was evaluated. It was found that as the level of Differential Privacy introduced by DP-SGD was increased, it became easier to distinguish between synthetic and real data.

Proper evaluation of quality and privacy of synthetic data is challenging. In this work, we utilised metrics from the SDV library due to their natural integration with the rest of the codebase. A valuable extension of this work would be to apply a variety of external metrics, including more advanced adversarial attacks to more thoroughly evaluate the privacy of the considered methods, including as the level of DP is varied. It would also be of interest to apply DP-SGD and/or PATE to all of the considered methods and evaluate whether the performance drop as a function of implemented privacy is similar or different across the models.

Currently the SynthVAE model only works for data which is 'clean', i.e. data that has no missingness or NaNs in its input. It can handle continuous, categorical and datetime variables. Special types such as nominal data cannot be handled properly; however, the model may still run. Column names have to be specified in the code for the variable group they belong to.

Hyperparameter tuning of the model can result in errors if certain parameter values are selected. Most commonly, changing the learning rate in our example results in errors during training. An extensive evaluation of plausible ranges has not yet been performed. If you get errors during tuning, consider your hyperparameter values and adjust them accordingly.

"},{"location":"model_card/#acknowledgements","title":"Acknowledgements","text":"

This documentation is inspired by Model Cards for Model Reporting (Mitchell et al.) and Lessons from Archives (Jo & Gebru).

"},{"location":"models/","title":"Adding new models","text":"

The model module contains all of the architectures implemented as part of this package. We offer GAN- and VAE-based architectures with a number of adjustments to achieve privacy and other augmented functionality. The module handles the training and generation of synthetic data using these architectures, per a user's choice of model(s) and configuration.

It is likely that, as the literature matures, more effective architectures will emerge for application to the type of tabular data NHSSynth is designed for. Below we discuss how to add new models to the package.

"},{"location":"models/#model-design","title":"Model design","text":"

The models in this package are built entirely in PyTorch and use Opacus for differential privacy.

We have built the VAE and (Tabular)GAN implementations in this package to serve as the foundations for a number of other architectures. As such, we try to maintain a somewhat modular design for building up more complex differentially private (or otherwise augmented) architectures. Each model inherits from either the GAN or VAE class (in files of the same name), which in turn inherit from a generic Model class found in the common folder. This folder contains components of models which are not to be instantiated themselves, e.g. a mixin class for differential privacy, the MLP underlying the GAN, and so on.

The Model class from which all of the models derive handles the general attributes. Roughly, these are the specifics of the dataset the model instance is tied to, the device that training is to be carried out on, and other training parameters such as the total number of epochs to execute.

We define these things at the model level because, when using differential privacy or other privacy-accountant methods, we must know the data and the length of training exposure ahead of time in order to calculate the level of noise required to reach a given privacy guarantee, and so on.

"},{"location":"models/#implementing-a-new-model","title":"Implementing a new model","text":"

To add a new architecture, first investigate the modular parts already implemented to ensure that what you want to build is not already possible through composition of these existing parts. Your architecture should then inherit from GAN or VAE, or from Model if you wish to implement a different type of generative model.

In all of these cases, the interface expects the implementation to provide the following methods (a minimal skeleton illustrating them follows the list):

  • get_args: a class method that lists the architecture-specific arguments that the model requires. This is used to facilitate default arguments in the Python API whilst still allowing arguments in the CLI to be propagated and recorded automatically in the experiment output. It should return a list of variable names equal to the concatenation of the arguments of all of the non-Model parent classes (e.g. DPVAE has DP and VAE args) plus any architecture-specific arguments in the __init__ method of the model in question.
  • get_metrics: another class method that behaves similarly to the above; it should return a list of valid metrics to track during training for this model.
  • train: a method handling the training loop for the model. This should take num_epochs, patience and displayed_metrics as arguments and return a tuple containing the number of epochs that were executed plus a bundle of training metrics (the values over time of the metrics returned by get_metrics on the class). During execution of this method, the utility methods defined in Model should be called in order: _start_training at the beginning, _record_metrics at each training step of the data loader, and finally _finish_training to clean up progress bars and so on. displayed_metrics determines which metrics are actively displayed during training.
  • generate: a method called on the trained model which generates N samples of data and uses the model's associated MetaTransformer to return a valid pandas DataFrame of synthetic data ready for output.
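
To make this interface concrete, the following is a minimal, hypothetical skeleton of a new architecture. It is a sketch only: the base-class import path, the helper-method signatures and the argument/metric names are assumptions inferred from the description above rather than verified API.

from nhssynth.modules.model.models.gan import GAN  # assumed import path\n\n\nclass MyNewGAN(GAN):\n    # hypothetical architecture extending the foundational GAN class\n\n    def __init__(self, *args, my_new_param: float = 0.1, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.my_new_param = my_new_param\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        # the parent's args plus this architecture's own __init__ arguments\n        return GAN.get_args() + ['my_new_param']\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return ['GLoss', 'DLoss']\n\n    def train(self, num_epochs, patience, displayed_metrics):\n        # call self._start_training here, loop over the data loader calling\n        # self._record_metrics at each step, then self._finish_training, and\n        # return (epochs_executed, metrics_bundle)\n        ...\n\n    def generate(self, N):\n        # sample N rows and pass them through the associated MetaTransformer\n        # to return a pandas DataFrame of synthetic data\n        ...\n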
"},{"location":"models/#adding-a-new-model-to-the-cli","title":"Adding a new model to the CLI","text":"

Once you have implemented your new model, you must add it to the CLI. To do this, first export the model's class into the MODELS constant in the __init__ file of the models subfolder. You can then add a new function and option in model_arguments.py to list the arguments and their types unique to this type of architecture.

Note that you should not duplicate arguments that are already defined in the Model class or in foundational model architectures such as the GAN if you are implementing an extension to one of them. If you have set up get_args correctly, all of this will be propagated automatically.
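
As a rough illustration of these registration steps, here is a hypothetical sketch; the file contents, class and argument names are illustrative rather than taken from the package.

import argparse\n\n# hypothetical sketch of the two registration steps above; names are illustrative\n\n# 1. in src/nhssynth/modules/model/models/__init__.py, expose the new class:\n#        MODELS = {..., 'MyNewGAN': MyNewGAN}\n\n# 2. in src/nhssynth/cli/model_arguments.py, declare its architecture-specific arguments\n#    (mirroring the existing add_gan_args / add_vae_args functions):\ndef add_mynewgan_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:\n    group.add_argument(\n        '--my-new-param',\n        type=float,\n        help='an architecture-specific hyperparameter of MyNewGAN',\n    )\n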

"},{"location":"modules/","title":"Adding new modules","text":"

The package is designed such that each module can be used as part of a pipeline (via the CLI or a configuration file) or independently (by importing it into an existing codebase).

In the future it may be desirable to add or adjust the modules of the package; this guide offers a high-level overview of how to do so.

"},{"location":"modules/#importing-a-module-from-this-package","title":"Importing a module from this package","text":"

After installing the package, you can simply do:

from nhssynth.modules import <module>\n
and you will be able to use it in your code!

"},{"location":"modules/#creating-a-new-module-and-folding-it-into-the-cli","title":"Creating a new module and folding it into the CLI","text":"

The following instructions specify how to extend this package with a new module:

  1. Create a folder for your module within the package, e.g. src/nhssynth/modules/mymodule
  2. Include within it a main executor function that accepts arguments from the CLI, i.e.

    def myexecutor(args):\n    ...\n

    Place this in mymodule/executor.py and export it by adding from .executor import myexecutor to mymodule/__init__.py. Check the existing modules for examples of what a typical executor function looks like.

  3. In the cli folder, add a corresponding function to module_arguments.py and populate it with the arguments you want to expose in the CLI:

    def add_mymodule_args(parser: argparse.ArgumentParser, group_title: str, overrides=False):\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(...)\n    group.add_argument(...)\n    ...\n
  4. Next, in module_setup.py make the following adjustments to the MODULE_MAP code:

    MODULE_MAP = {\n    ...\n    \"mymodule\": ModuleConfig(\n        func=m.mymodule.myexecutor,\n        add_args=ma.add_mymodule_args,\n        description=\"...\",\n        help=\"...\",\n        common_parsers=[...]\n    ),\n    ...\n}\n

    Where common_parsers is a subset of COMMON_PARSERS defined in common_arguments.py. Note that the \"seed\" and \"core\" parsers are added automatically, so you don't need to specify them. These parsers can be used to add arguments to your module that are common to multiple modules, e.g. the dataloader and evaluation modules both use --typed to specify the path of the typed input dataset.

  5. You can (optionally) also edit the following block if you want your module to be included in a full pipeline run:

    PIPELINE = [..., mymodule, ...]  # NOTE this determines the order of a pipeline run\n
  6. Congratulations, your module is now implemented within the CLI; its documentation will be built automatically and it can be referenced in configuration files!

"},{"location":"secure_mode/","title":"Opacus' secure mode","text":"

Part of the process for achieving a differential privacy guarantee under Opacus involves generating noise according to a Gaussian distribution with mean 0 in Opacus' _generate_noise() function.

Enabling secure_mode when using the NHSSynth package ensures that the generated noise is also secure against floating point representation attacks, such as the ones in https://arxiv.org/abs/2107.10138 and https://arxiv.org/abs/2112.05307.

This attack first appeared in https://arxiv.org/abs/2112.05307; the fix via the csprng package is based on https://arxiv.org/abs/2107.10138 and involves calling the Gaussian noise function $2n$ times, where $n=2$ (see section 5.1 in https://arxiv.org/abs/2107.10138).

The reason for choosing $n=2$ is that $n$ can be any number greater than $1$, but the bigger $n$ is, the more computation is needed to generate the Gaussian samples. The choice of $n=2$ is justified by the fact that the attack has a complexity of $2^{p(2n-1)}$. In PyTorch, $p=53$, so with $n=2$ the exponent is $p(2n-1) = 159$ and the complexity is $2^{159}$, which is deemed sufficiently hard for an attacker to break.

"},{"location":"reference/SUMMARY/","title":"SUMMARY","text":"
  • cli
    • common_arguments
    • config
    • model_arguments
    • module_arguments
    • module_setup
    • run
  • common
    • common
    • constants
    • debugging
    • dicts
    • io
    • strings
  • modules
    • dashboard
      • Upload
      • io
      • pages
        • 1_Tables
        • 2_Plots
        • 3_Experiment_Configurations
      • run
      • utils
    • dataloader
      • constraints
      • io
      • metadata
      • metatransformer
      • missingness
      • run
      • transformers
        • base
        • categorical
        • continuous
        • datetime
    • evaluation
      • aequitas
      • io
      • metrics
      • run
      • tasks
      • utils
    • model
      • common
        • dp
        • mlp
        • model
      • io
      • models
        • dpvae
        • gan
        • vae
      • run
      • utils
    • plotting
      • io
      • plots
      • run
    • structure
      • run
"},{"location":"reference/cli/","title":"cli","text":""},{"location":"reference/cli/common_arguments/","title":"common_arguments","text":"

Functions to define the CLI's \"common\" arguments, i.e. those that can be applied to either all module argument lists (e.g. --dataset, --seed, etc.) or a subset of modules' argument lists (e.g. --synthetic, --typed, etc.).

"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.get_core_parser","title":"get_core_parser(overrides=False)","text":"

Create the core common parser group applied to all modules (and the pipeline and config options). Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.

Parameters:

Name Type Description Default overrides

whether the arguments declared within are required or not.

False

Returns:

Type Description ArgumentParser

The parser with the group containing the core arguments attached.

Source code in src/nhssynth/cli/common_arguments.py
def get_core_parser(overrides=False) -> argparse.ArgumentParser:\n    \"\"\"\n    Create the core common parser group applied to all modules (and the `pipeline` and `config` options).\n    Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.\n\n    Args:\n        overrides: whether the arguments declared within are required or not.\n\n    Returns:\n        The parser with the group containing the core arguments attached.\n    \"\"\"\n    \"\"\"\"\"\"\n    core = argparse.ArgumentParser(add_help=False)\n    core_grp = core.add_argument_group(title=\"options\")\n    core_grp.add_argument(\n        \"-d\",\n        \"--dataset\",\n        required=(not overrides),\n        type=str,\n        help=\"the name of the dataset to experiment with, should be present in `<DATA_DIR>`\",\n    )\n    core_grp.add_argument(\n        \"-e\",\n        \"--experiment-name\",\n        type=str,\n        default=TIME,\n        help=\"name the experiment run to affect logging, config, and default-behaviour i/o\",\n    )\n    core_grp.add_argument(\n        \"--save-config\",\n        action=\"store_true\",\n        help=\"save the config provided via the cli, this is a recommended option for reproducibility\",\n    )\n    return core\n
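
For illustration, here is a hypothetical way of reusing this parser when composing a module-level parser; the composition and argument values below are illustrative rather than taken from the package.

import argparse\n\nfrom nhssynth.cli.common_arguments import get_core_parser\n\n# hypothetical usage: supplying the core parser via parents= attaches its\n# argument group to a module-level parser\ncore = get_core_parser()\nmodule_parser = argparse.ArgumentParser(parents=[core])\nargs = module_parser.parse_args(['--dataset', 'support'])\nprint(args.dataset, args.experiment_name, args.save_config)\n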
"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.get_seed_parser","title":"get_seed_parser(overrides=False)","text":"

Create the common parser group for the seed. NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.

Parameters:

Name Type Description Default overrides

whether the arguments declared within are required or not.

False

Returns:

Type Description ArgumentParser

The parser with the group containing the seed argument attached.

Source code in src/nhssynth/cli/common_arguments.py
def get_seed_parser(overrides=False) -> argparse.ArgumentParser:\n    \"\"\"\n    Create the common parser group for the seed.\n    NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.\n\n    Args:\n        overrides: whether the arguments declared within are required or not.\n\n    Returns:\n        The parser with the group containing the seed argument attached.\n    \"\"\"\n    parser = argparse.ArgumentParser(add_help=False)\n    parser_grp = parser.add_argument_group(title=\"options\")\n    parser_grp.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        help=\"specify a seed for reproducibility, this is a recommended option for reproducibility\",\n    )\n    return parser\n
"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.suffix_parser_generator","title":"suffix_parser_generator(name, help, required=False)","text":"

Generator function for creating parsers following a common template. These parsers are all suffixes to the --dataset / -d / DATASET argument, see COMMON_TITLE.

Parameters:

Name Type Description Default name str

the name / label of the argument to add to the CLI options.

required help str

the help message when the CLI is run with --help / -h.

required required bool

whether the argument must be provided or not.

False Source code in src/nhssynth/cli/common_arguments.py
def suffix_parser_generator(name: str, help: str, required: bool = False) -> argparse.ArgumentParser:\n    \"\"\"Generator function for creating parsers following a common template.\n    These parsers are all suffixes to the --dataset / -d / DATASET argument, see `COMMON_TITLE`.\n\n    Args:\n        name: the name / label of the argument to add to the CLI options.\n        help: the help message when the CLI is run with --help / -h.\n        required: whether the argument must be provided or not.\n    \"\"\"\n\n    def get_parser(overrides: bool = False) -> argparse.ArgumentParser:\n        parser = argparse.ArgumentParser(add_help=False)\n        parser_grp = parser.add_argument_group(title=COMMON_TITLE)\n        parser_grp.add_argument(\n            f\"--{name.replace('_', '-')}\",\n            required=required and not overrides,\n            type=str,\n            default=f\"_{name}\",\n            help=help,\n        )\n        return parser\n\n    return get_parser\n
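
For illustration, a hypothetical invocation of this generator; the suffix name and help text are made up, though --typed is one of the common arguments mentioned elsewhere in these docs.

from nhssynth.cli.common_arguments import suffix_parser_generator\n\n# hypothetical usage; the suffix name and help text are illustrative\nget_typed_parser = suffix_parser_generator(\n    name='typed',\n    help='the name of the typed version of the dataset',\n)\ntyped_parser = get_typed_parser()  # exposes --typed with the default value '_typed'\n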
"},{"location":"reference/cli/config/","title":"config","text":"

Read, write and process config files, including handling of module-specific / common config overrides.

"},{"location":"reference/cli/config/#nhssynth.cli.config.assemble_config","title":"assemble_config(args, all_subparsers)","text":"

Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.

Parameters:

Name Type Description Default args Namespace

A namespace object containing all parsed command-line arguments.

required all_subparsers dict[str, ArgumentParser]

A dictionary mapping module names to subparser objects.

required

Returns:

Type Description dict[str, Any]

A dictionary containing configuration information extracted from args in a module-wise nested format that is YAML-friendly.

Raises:

Type Description ValueError

If a module specified in args.modules_to_run is not in all_subparsers.

Source code in src/nhssynth/cli/config.py
def assemble_config(\n    args: argparse.Namespace,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> dict[str, Any]:\n    \"\"\"\n    Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.\n\n    Args:\n        args: A namespace object containing all parsed command-line arguments.\n        all_subparsers: A dictionary mapping module names to subparser objects.\n\n    Returns:\n        A dictionary containing configuration information extracted from `args` in a module-wise nested format that is YAML-friendly.\n\n    Raises:\n        ValueError: If a module specified in `args.modules_to_run` is not in `all_subparsers`.\n    \"\"\"\n    args_dict = vars(args)\n\n    # Filter out the keys that are not relevant to the config file\n    args_dict = filter_dict(\n        args_dict, {\"func\", \"experiment_name\", \"save_config\", \"save_config_path\", \"module_handover\"}\n    )\n    for k in args_dict.copy().keys():\n        # Remove empty metric lists from the config\n        if \"_metrics\" in k and not args_dict[k]:\n            args_dict.pop(k)\n\n    modules_to_run = args_dict.pop(\"modules_to_run\")\n    if len(modules_to_run) == 1:\n        run_type = modules_to_run[0]\n    elif modules_to_run == PIPELINE:\n        run_type = \"pipeline\"\n    else:\n        raise ValueError(f\"Invalid value for `modules_to_run`: {modules_to_run}\")\n\n    # Generate a dictionary containing each module's name from the run, with all of its possible corresponding config args\n    module_args = {\n        module_name: [action.dest for action in all_subparsers[module_name]._actions if action.dest != \"help\"]\n        for module_name in modules_to_run\n    }\n\n    # Use the flat namespace to populate a nested (by module) dictionary of config args and values\n    out_dict = {}\n    for module_name in modules_to_run:\n        for k in args_dict.copy().keys():\n            # We want to keep dataset, experiment_name, seed and save_config at the top-level as they are core args\n            if k in module_args[module_name] and k not in {\n                \"version\",\n                \"dataset\",\n                \"experiment_name\",\n                \"seed\",\n                \"save_config\",\n            }:\n                if module_name not in out_dict:\n                    out_dict[module_name] = {}\n                v = args_dict.pop(k)\n                if v is not None:\n                    out_dict[module_name][k] = v\n\n    # Assemble the final dictionary in YAML-compliant form\n    return {**({\"run_type\": run_type} if run_type else {}), **args_dict, **out_dict}\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.get_default_and_required_args","title":"get_default_and_required_args(top_parser, module_parsers)","text":"

Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.

Parameters:

Name Type Description Default top_parser ArgumentParser

The top-level parser (contains common arguments).

required module_parsers dict[str, ArgumentParser]

The dict of module-level parsers mapped to their names.

required

Returns:

Type Description tuple[dict[str, Any], list[str]]

A tuple containing two elements: - A dictionary containing all arguments and their default values. - A list of key-value-pairs of the required arguments and their associated module.

Source code in src/nhssynth/cli/config.py
def get_default_and_required_args(\n    top_parser: argparse.ArgumentParser,\n    module_parsers: dict[str, argparse.ArgumentParser],\n) -> tuple[dict[str, Any], list[str]]:\n    \"\"\"\n    Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.\n\n    Args:\n        top_parser: The top-level parser (contains common arguments).\n        module_parsers: The dict of module-level parsers mapped to their names.\n\n    Returns:\n        A tuple containing two elements:\n            - A dictionary containing all arguments and their default values.\n            - A list of key-value-pairs of the required arguments and their associated module.\n    \"\"\"\n    all_actions = {\"top-level\": top_parser._actions} | {m: p._actions for m, p in module_parsers.items()}\n    defaults = {}\n    required_args = []\n    for module, actions in all_actions.items():\n        for action in actions:\n            if action.dest not in [\"help\", \"==SUPPRESS==\"]:\n                defaults[action.dest] = action.default\n                if action.required:\n                    required_args.append({\"arg\": action.dest, \"module\": module})\n    return defaults, required_args\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.get_modules_to_run","title":"get_modules_to_run(executor)","text":"

Get the list of modules to run from the passed executor function.

Parameters:

Name Type Description Default executor Callable

The executor function to run.

required

Returns:

Type Description list[str]

A list of module names to run.

Source code in src/nhssynth/cli/config.py
def get_modules_to_run(executor: Callable) -> list[str]:\n    \"\"\"\n    Get the list of modules to run from the passed executor function.\n\n    Args:\n        executor: The executor function to run.\n\n    Returns:\n        A list of module names to run.\n    \"\"\"\n    if executor == run_pipeline:\n        return PIPELINE\n    else:\n        return [get_key_by_value({mn: mc.func for mn, mc in MODULE_MAP.items()}, executor)]\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.read_config","title":"read_config(args, parser, all_subparsers)","text":"

Hierarchically assembles a config argparse.Namespace object for the inferred modules to run and execute, given a file.

  1. Load the YAML file containing the config to read from
  2. Check a valid run_type is specified or infer it and determine the list of modules_to_run
  3. Establish the appropriate default configuration set of arguments from the parser and all_subparsers for the determined modules_to_run
  4. Overwrite these with the specified (sub)set of config in the YAML file
  5. Overwrite again with passed command-line args (these are considered 'overrides')
  6. Run the appropriate module(s) or pipeline with the resulting configuration Namespace object

Parameters:

Name Type Description Default args Namespace

Namespace object containing arguments from the command line

required parser ArgumentParser

top-level ArgumentParser object containing common arguments

required all_subparsers dict[str, ArgumentParser]

dictionary of ArgumentParser objects, one for each module

required

Returns:

Type Description Namespace

A Namespace object containing the assembled configuration settings

Raises:

Type Description AssertionError

if any required arguments are missing from the configuration file / overrides

Source code in src/nhssynth/cli/config.py
def read_config(\n    args: argparse.Namespace,\n    parser: argparse.ArgumentParser,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> argparse.Namespace:\n    \"\"\"\n    Hierarchically assembles a config `argparse.Namespace` object for the inferred modules to run and execute, given a file.\n\n    1. Load the YAML file containing the config to read from\n    2. Check a valid `run_type` is specified or infer it and determine the list of `modules_to_run`\n    3. Establish the appropriate default configuration set of arguments from the `parser` and `all_subparsers` for the determined `modules_to_run`\n    4. Overwrite these with the specified (sub)set of config in the YAML file\n    5. Overwrite again with passed command-line `args` (these are considered 'overrides')\n    6. Run the appropriate module(s) or pipeline with the resulting configuration `Namespace` object\n\n    Args:\n        args: Namespace object containing arguments from the command line\n        parser: top-level `ArgumentParser` object containing common arguments\n        all_subparsers: dictionary of `ArgumentParser` objects, one for each module\n\n    Returns:\n        A Namespace object containing the assembled configuration settings\n\n    Raises:\n        AssertionError: if any required arguments are missing from the configuration file / overrides\n    \"\"\"\n    # Open the passed yaml file and load into a dictionary\n    with open(f\"config/{args.input_config}.yaml\") as stream:\n        config_dict = yaml.safe_load(stream)\n\n    valid_run_types = [x for x in all_subparsers.keys() if x != \"config\"]\n\n    version = config_dict.pop(\"version\", None)\n    if version and version != version(\"nhssynth\"):\n        warnings.warn(\n            f\"This config file's specified version ({version}) does not match the currently installed version of nhssynth ({version('nhssynth')}), results may differ.\"\n        )\n    elif not version:\n        version = ver(\"nhssynth\")\n\n    run_type = config_dict.pop(\"run_type\", None)\n\n    if run_type == \"pipeline\":\n        modules_to_run = PIPELINE\n    else:\n        modules_to_run = [x for x in config_dict.keys() | {run_type} if x in valid_run_types]\n        if not args.custom_pipeline:\n            modules_to_run = sorted(modules_to_run, key=lambda x: PIPELINE.index(x))\n\n    if not modules_to_run:\n        warnings.warn(\n            \"Missing or invalid `run_type` and / or module specification hierarchy in `config/{args.input_config}.yaml`, defaulting to a full run of the pipeline\"\n        )\n        modules_to_run = PIPELINE\n\n    # Get all possible default arguments by scraping the top level `parser` and the appropriate sub-parser for the `run_type`\n    args_dict, required_args = get_default_and_required_args(\n        parser, filter_dict(all_subparsers, modules_to_run, include=True)\n    )\n\n    # Find the non-default arguments amongst passed `args` by seeing which of them are different to the entries of `args_dict`\n    non_default_passed_args_dict = {\n        k: v\n        for k, v in vars(args).items()\n        if k in [\"input_config\", \"custom_pipeline\"] or (k in args_dict and k != \"func\" and v != args_dict[k])\n    }\n\n    # Overwrite the default arguments with the ones from the yaml file\n    args_dict.update(flatten_dict(config_dict))\n\n    # Overwrite the result of the above with any non-default CLI args\n    args_dict.update(non_default_passed_args_dict)\n\n    # Create a new Namespace using the assembled dictionary\n    new_args = 
argparse.Namespace(**args_dict)\n    assert getattr(\n        new_args, \"dataset\"\n    ), \"No dataset specified in the passed config file, provide one with the `--dataset` argument or add it to the config file\"\n    assert all(\n        getattr(new_args, req_arg[\"arg\"]) for req_arg in required_args\n    ), f\"Required arguments are missing from the passed config file: {[ra['module'] + ':' + ra['arg'] for ra in required_args if not getattr(new_args, ra['arg'])]}\"\n\n    # Run the appropriate execution function(s)\n    if not new_args.seed:\n        warnings.warn(\"No seed has been specified, meaning the results of this run may not be reproducible.\")\n    new_args.version = version\n    new_args.modules_to_run = modules_to_run\n    new_args.module_handover = {}\n    for module in new_args.modules_to_run:\n        MODULE_MAP[module](new_args)\n\n    return new_args\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.write_config","title":"write_config(args, all_subparsers)","text":"

Assembles a configuration dictionary from the run config and writes it to a YAML file at the location specified by args.save_config_path.

Parameters:

Name Type Description Default args Namespace

A namespace containing the run's configuration.

required all_subparsers dict[str, ArgumentParser]

A dictionary containing all subparsers for the config args.

required Source code in src/nhssynth/cli/config.py
def write_config(\n    args: argparse.Namespace,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> None:\n    \"\"\"\n    Assembles a configuration dictionary from the run config and writes it to a YAML file at the location specified by `args.save_config_path`.\n\n    Args:\n        args: A namespace containing the run's configuration.\n        all_subparsers: A dictionary containing all subparsers for the config args.\n    \"\"\"\n    experiment_name = args.experiment_name\n    args_dict = assemble_config(args, all_subparsers)\n    with open(f\"experiments/{experiment_name}/config_{experiment_name}.yaml\", \"w\") as yaml_file:\n        yaml.dump(args_dict, yaml_file, default_flow_style=False, sort_keys=False)\n
"},{"location":"reference/cli/model_arguments/","title":"model_arguments","text":"

Define arguments for each of the model classes.

"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_gan_args","title":"add_gan_args(group, overrides=False)","text":"

Adds arguments to an existing group for the GAN model.

Source code in src/nhssynth/cli/model_arguments.py
def add_gan_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group for the GAN model.\"\"\"\n    group.add_argument(\n        \"--n-units-conditional\",\n        type=int,\n        help=\"the number of units in the conditional layer\",\n    )\n    group.add_argument(\n        \"--generator-n-layers-hidden\",\n        type=int,\n        help=\"the number of hidden layers in the generator\",\n    )\n    group.add_argument(\n        \"--generator-n-units-hidden\",\n        type=int,\n        help=\"the number of units in each hidden layer of the generator\",\n    )\n    group.add_argument(\n        \"--generator-activation\",\n        type=str,\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the generator\",\n    )\n    group.add_argument(\n        \"--generator-batch-norm\",\n        action=\"store_true\",\n        help=\"whether to use batch norm in the generator\",\n    )\n    group.add_argument(\n        \"--generator-dropout\",\n        type=float,\n        help=\"the dropout rate in the generator\",\n    )\n    group.add_argument(\n        \"--generator-lr\",\n        type=float,\n        help=\"the learning rate for the generator\",\n    )\n    group.add_argument(\n        \"--generator-residual\",\n        action=\"store_true\",\n        help=\"whether to use residual connections in the generator\",\n    )\n    group.add_argument(\n        \"--generator-opt-betas\",\n        type=float,\n        nargs=2,\n        help=\"the beta values for the generator optimizer\",\n    )\n    group.add_argument(\n        \"--discriminator-n-layers-hidden\",\n        type=int,\n        help=\"the number of hidden layers in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-n-units-hidden\",\n        type=int,\n        help=\"the number of units in each hidden layer of the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-activation\",\n        type=str,\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-batch-norm\",\n        action=\"store_true\",\n        help=\"whether to use batch norm in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-dropout\",\n        type=float,\n        help=\"the dropout rate in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-lr\",\n        type=float,\n        help=\"the learning rate for the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-opt-betas\",\n        type=float,\n        nargs=2,\n        help=\"the beta values for the discriminator optimizer\",\n    )\n    group.add_argument(\n        \"--clipping-value\",\n        type=float,\n        help=\"the clipping value for the discriminator\",\n    )\n    group.add_argument(\n        \"--lambda-gradient-penalty\",\n        type=float,\n        help=\"the gradient penalty coefficient\",\n    )\n
"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_model_specific_args","title":"add_model_specific_args(group, name, overrides=False)","text":"

Adds arguments to an existing group according to name.

Source code in src/nhssynth/cli/model_arguments.py
def add_model_specific_args(group: argparse._ArgumentGroup, name: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group according to `name`.\"\"\"\n    if name == \"VAE\":\n        add_vae_args(group, overrides)\n    elif name == \"GAN\":\n        add_gan_args(group, overrides)\n    elif name == \"TabularGAN\":\n        add_tabular_gan_args(group, overrides)\n
"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_vae_args","title":"add_vae_args(group, overrides=False)","text":"

Adds arguments to an existing group for the VAE model.

Source code in src/nhssynth/cli/model_arguments.py
def add_vae_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group for the VAE model.\"\"\"\n    group.add_argument(\n        \"--encoder-latent-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the latent dimension of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-hidden-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the hidden dimension of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-activation\",\n        type=str,\n        nargs=\"+\",\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-learning-rate\",\n        type=float,\n        nargs=\"+\",\n        help=\"the learning rate for the encoder\",\n    )\n    group.add_argument(\n        \"--decoder-latent-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the latent dimension of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-hidden-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the hidden dimension of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-activation\",\n        type=str,\n        nargs=\"+\",\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-learning-rate\",\n        type=float,\n        nargs=\"+\",\n        help=\"the learning rate for the decoder\",\n    )\n    group.add_argument(\n        \"--shared-optimizer\",\n        action=\"store_true\",\n        help=\"whether to use a shared optimizer for the encoder and decoder\",\n    )\n
"},{"location":"reference/cli/module_arguments/","title":"module_arguments","text":"

Define arguments for each of the modules' CLI sub-parsers.

"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.AllChoicesDefault","title":"AllChoicesDefault","text":"

Bases: Action

Customised argparse action for defaulting to the full list of choices if only the argument's flag is supplied: (i.e. user passes --metrics with no follow up list of metric groups => all metric groups will be executed).

Notes

1) If no option_string is supplied: set to default value (self.default) 2) If option_string is supplied: a) If values are supplied, set to list of values b) If no values are supplied, set to self.const, if self.const is not set, set to self.default

Source code in src/nhssynth/cli/module_arguments.py
class AllChoicesDefault(argparse.Action):\n    \"\"\"\n    Customised argparse action for defaulting to the full list of choices if only the argument's flag is supplied:\n    (i.e. user passes `--metrics` with no follow up list of metric groups => all metric groups will be executed).\n\n    Notes:\n        1) If no `option_string` is supplied: set to default value (`self.default`)\n        2) If `option_string` is supplied:\n            a) If `values` are supplied, set to list of values\n            b) If no `values` are supplied, set to `self.const`, if `self.const` is not set, set to `self.default`\n    \"\"\"\n\n    def __call__(self, parser, namespace, values=None, option_string=None):\n        if values:\n            setattr(namespace, self.dest, values)\n        elif option_string:\n            setattr(namespace, self.dest, self.const if self.const else self.default)\n        else:\n            setattr(namespace, self.dest, self.default)\n
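
For illustration, a hypothetical parser using this action; the argument name and metric-group values are made up.

import argparse\n\nfrom nhssynth.cli.module_arguments import AllChoicesDefault\n\n# hypothetical parser; the argument name and metric groups are made up\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--metrics',\n    nargs='*',\n    action=AllChoicesDefault,\n    default=None,\n    const=['quality', 'privacy'],  # used when the flag is supplied with no values\n)\nprint(parser.parse_args([]).metrics)                        # None (the default)\nprint(parser.parse_args(['--metrics']).metrics)             # ['quality', 'privacy'] (the const)\nprint(parser.parse_args(['--metrics', 'quality']).metrics)  # ['quality']\n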
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_dataloader_args","title":"add_dataloader_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing dataloader module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_dataloader_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing dataloader module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--data-dir\",\n        type=str,\n        default=\"./data\",\n        help=\"the directory containing the chosen dataset\",\n    )\n    group.add_argument(\n        \"--index-col\",\n        default=None,\n        nargs=\"*\",\n        help=\"indicate the name of the index column(s) in the csv file, such that pandas can index by it\",\n    )\n    group.add_argument(\n        \"--constraint-graph\",\n        type=str,\n        default=\"_constraint_graph\",\n        help=\"the name of the html file to write the constraint graph to, defaults to `<DATASET>_constraint_graph`\",\n    )\n    group.add_argument(\n        \"--collapse-yaml\",\n        action=\"store_true\",\n        help=\"use aliases and anchors in the output metadata yaml, this will make it much more compact\",\n    )\n    group.add_argument(\n        \"--missingness\",\n        type=str,\n        default=\"augment\",\n        choices=MISSINGNESS_STRATEGIES,\n        help=\"how to handle missing values in the dataset\",\n    )\n    group.add_argument(\n        \"--impute\",\n        type=str,\n        default=None,\n        help=\"the imputation strategy to use, ONLY USED if <MISSINGNESS> is set to 'impute', choose from: 'mean', 'median', 'mode', or any specific value (e.g. '0')\",\n    )\n    group.add_argument(\n        \"--write-csv\",\n        action=\"store_true\",\n        help=\"write the transformed real data to a csv file\",\n    )\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_evaluation_args","title":"add_evaluation_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing evaluation module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_evaluation_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing evaluation module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--downstream-tasks\",\n        \"--tasks\",\n        action=\"store_true\",\n        help=\"run the downstream tasks evaluation\",\n    )\n    group.add_argument(\n        \"--tasks-dir\",\n        type=str,\n        default=\"./tasks\",\n        help=\"the directory containing the downstream tasks to run, this directory must contain a folder called <DATASET> containing the tasks to run\",\n    )\n    group.add_argument(\n        \"--aequitas\",\n        action=\"store_true\",\n        help=\"run the aequitas fairness evaluation (note this runs for each of the downstream tasks)\",\n    )\n    group.add_argument(\n        \"--aequitas-attributes\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the attributes to use for the aequitas fairness evaluation, defaults to all attributes\",\n    )\n    group.add_argument(\n        \"--key-numerical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the numerical key field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--sensitive-numerical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the numerical sensitive field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--key-categorical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the categorical key field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--sensitive-categorical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the categorical sensitive field attributes to use for SDV privacy evaluations\",\n    )\n    for name in METRIC_CHOICES:\n        generate_evaluation_arg(group, name)\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_model_args","title":"add_model_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing model module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing model module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--architecture\",\n        type=str,\n        nargs=\"+\",\n        default=[\"VAE\"],\n        choices=MODELS,\n        help=\"the model architecture(s) to train\",\n    )\n    group.add_argument(\n        \"--repeats\",\n        type=int,\n        default=1,\n        help=\"how many times to repeat the training process per model architecture (<SEED> is incremented each time)\",\n    )\n    group.add_argument(\n        \"--batch-size\",\n        type=int,\n        nargs=\"+\",\n        default=32,\n        help=\"the batch size for the model\",\n    )\n    group.add_argument(\n        \"--num-epochs\",\n        type=int,\n        nargs=\"+\",\n        default=100,\n        help=\"number of epochs to train for\",\n    )\n    group.add_argument(\n        \"--patience\",\n        type=int,\n        nargs=\"+\",\n        default=5,\n        help=\"how many epochs the model is allowed to train for without improvement\",\n    )\n    group.add_argument(\n        \"--displayed-metrics\",\n        type=str,\n        nargs=\"+\",\n        default=[],\n        help=\"metrics to display during training of the model, when set to `None`, all metrics are displayed\",\n    )\n    group.add_argument(\n        \"--use-gpu\",\n        action=\"store_true\",\n        help=\"use the GPU for training\",\n    )\n    group.add_argument(\n        \"--num-samples\",\n        type=int,\n        default=None,\n        help=\"the number of samples to generate from the model, defaults to the size of the original dataset\",\n    )\n    privacy_group = parser.add_argument_group(title=\"model privacy options\")\n    privacy_group.add_argument(\n        \"--target-epsilon\",\n        type=float,\n        nargs=\"+\",\n        default=1.0,\n        help=\"the target epsilon for differential privacy\",\n    )\n    privacy_group.add_argument(\n        \"--target-delta\",\n        type=float,\n        nargs=\"+\",\n        help=\"the target delta for differential privacy, defaults to `1 / len(dataset)` if not specified\",\n    )\n    privacy_group.add_argument(\n        \"--max-grad-norm\",\n        type=float,\n        nargs=\"+\",\n        default=5.0,\n        help=\"the clipping threshold for gradients (only relevant under differential privacy)\",\n    )\n    privacy_group.add_argument(\n        \"--secure-mode\",\n        action=\"store_true\",\n        help=\"Enable secure RNG via the `csprng` package to make privacy guarantees more robust, comes at a cost of performance and reproducibility\",\n    )\n    for model_name in MODELS.keys():\n        model_group = parser.add_argument_group(title=f\"{model_name}-specific options\")\n        add_model_specific_args(model_group, model_name, overrides=overrides)\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_plotting_args","title":"add_plotting_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing plotting module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_plotting_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing plotting module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--plot-quality\",\n        action=\"store_true\",\n        help=\"plot the SDV quality report\",\n    )\n    group.add_argument(\n        \"--plot-diagnostic\",\n        action=\"store_true\",\n        help=\"plot the SDV diagnostic report\",\n    )\n    group.add_argument(\n        \"--plot-sdv-report\",\n        action=\"store_true\",\n        help=\"plot the SDV report\",\n    )\n    group.add_argument(\n        \"--plot-tsne\",\n        action=\"store_true\",\n        help=\"plot the t-SNE embeddings of the real and synthetic data\",\n    )\n
"},{"location":"reference/cli/module_setup/","title":"module_setup","text":"

Specify all CLI-accessible modules and their configurations, the pipeline to run by default, and define special functions for the config and pipeline CLI option trees.

"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.ModuleConfig","title":"ModuleConfig","text":"

Represents a module's configuration, containing the following attributes:

Attributes:

Name Type Description func

A callable that executes the module's functionality.

add_args

A callable that populates the module's sub-parser arguments.

description

A description of the module's functionality.

help

A help message for the module's command-line interface.

common_parsers

A list of common parsers to add to the module's sub-parser, appending the 'dataset' and 'core' parsers to those passed.

Source code in src/nhssynth/cli/module_setup.py
class ModuleConfig:\n    \"\"\"\n    Represents a module's configuration, containing the following attributes:\n\n    Attributes:\n        func: A callable that executes the module's functionality.\n        add_args: A callable that populates the module's sub-parser arguments.\n        description: A description of the module's functionality.\n        help: A help message for the module's command-line interface.\n        common_parsers: A list of common parsers to add to the module's sub-parser, appending the 'dataset' and 'core' parsers to those passed.\n    \"\"\"\n\n    def __init__(\n        self,\n        func: Callable[..., argparse.Namespace],\n        add_args: Callable[..., None],\n        description: str,\n        help: str,\n        common_parsers: Optional[list[str]] = None,\n        no_seed: bool = False,\n    ) -> None:\n        self.func = func\n        self.add_args = add_args\n        self.description = description\n        self.help = help\n        self.common_parsers = [\"core\", \"seed\"] if not no_seed else [\"core\"]\n        if common_parsers:\n            assert set(common_parsers) <= COMMON_PARSERS.keys(), \"Invalid common parser(s) specified.\"\n            # merge the below two assert statements\n            assert (\n                \"core\" not in common_parsers and \"seed\" not in common_parsers\n            ), \"The 'seed' and 'core' parser groups are automatically added to all modules, remove the from `ModuleConfig`s.\"\n            self.common_parsers += common_parsers\n\n    def __call__(self, args: argparse.Namespace) -> argparse.Namespace:\n        return self.func(args)\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_config_args","title":"add_config_args(parser)","text":"

Adds arguments to parser relating to configuration file handling and module-specific config overrides.

Source code in src/nhssynth/cli/module_setup.py
def add_config_args(parser: argparse.ArgumentParser) -> None:\n    \"\"\"Adds arguments to `parser` relating to configuration file handling and module-specific config overrides.\"\"\"\n    parser.add_argument(\n        \"-c\",\n        \"--input-config\",\n        required=True,\n        help=\"specify the config file name\",\n    )\n    parser.add_argument(\n        \"-cp\",\n        \"--custom-pipeline\",\n        action=\"store_true\",\n        help=\"infer a custom pipeline running order of modules from the config\",\n    )\n    for module_name in PIPELINE:\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} option overrides\", overrides=True)\n    for module_name in VALID_MODULES - set(PIPELINE):\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} options overrides\", overrides=True)\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_pipeline_args","title":"add_pipeline_args(parser)","text":"

Adds arguments to parser for each module in the pipeline.

Source code in src/nhssynth/cli/module_setup.py
def add_pipeline_args(parser: argparse.ArgumentParser) -> None:\n    \"\"\"Adds arguments to `parser` for each module in the pipeline.\"\"\"\n    for module_name in PIPELINE:\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} options\")\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_subparser","title":"add_subparser(subparsers, name, module_config)","text":"

Add a subparser to an argparse argument parser.

Parameters:

subparsers (_SubParsersAction): The subparsers action to which the subparser will be added. Required.
name (str): The name of the subparser. Required.
module_config (ModuleConfig): A ModuleConfig object containing information about the subparser, including a function to execute and a function to add arguments. Required.

Returns:

ArgumentParser: The newly created subparser.

Source code in src/nhssynth/cli/module_setup.py
def add_subparser(\n    subparsers: argparse._SubParsersAction,\n    name: str,\n    module_config: ModuleConfig,\n) -> argparse.ArgumentParser:\n    \"\"\"\n    Add a subparser to an argparse argument parser.\n\n    Args:\n        subparsers: The subparsers action to which the subparser will be added.\n        name: The name of the subparser.\n        module_config: A [`ModuleConfig`][nhssynth.cli.module_setup.ModuleConfig] object containing information about the subparser, including a function to execute and a function to add arguments.\n\n    Returns:\n        The newly created subparser.\n    \"\"\"\n    parent_parsers = get_parent_parsers(name, module_config.common_parsers)\n    parser = subparsers.add_parser(\n        name=name,\n        description=module_config.description,\n        help=module_config.help,\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n        parents=parent_parsers,\n    )\n    if name not in {\"pipeline\", \"config\"}:\n        module_config.add_args(parser, f\"{name} options\")\n    else:\n        module_config.add_args(parser)\n    parser.set_defaults(func=module_config.func)\n    return parser\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.get_parent_parsers","title":"get_parent_parsers(name, module_parsers)","text":"

Get a list of parent parsers for a given module, based on the module's common_parsers attribute.

Source code in src/nhssynth/cli/module_setup.py
def get_parent_parsers(name: str, module_parsers: list[str]) -> list[argparse.ArgumentParser]:\n    \"\"\"Get a list of parent parsers for a given module, based on the module's `common_parsers` attribute.\"\"\"\n    if name in {\"pipeline\", \"config\"}:\n        return [p(name == \"config\") for p in COMMON_PARSERS.values()]\n    elif name == \"dashboard\":\n        return [COMMON_PARSERS[pn](True) for pn in module_parsers]\n    else:\n        return [COMMON_PARSERS[pn]() for pn in module_parsers]\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.run_pipeline","title":"run_pipeline(args)","text":"

Runs the specified pipeline of modules with the passed configuration args.

Source code in src/nhssynth/cli/module_setup.py
def run_pipeline(args: argparse.Namespace) -> None:\n    \"\"\"Runs the specified pipeline of modules with the passed configuration `args`.\"\"\"\n    print(\"Running full pipeline...\")\n    args.modules_to_run = PIPELINE\n    for module_name in PIPELINE:\n        args = MODULE_MAP[module_name](args)\n
"},{"location":"reference/cli/run/","title":"run","text":""},{"location":"reference/common/","title":"common","text":""},{"location":"reference/common/common/","title":"common","text":"

Common functions for all modules.

"},{"location":"reference/common/common/#nhssynth.common.common.set_seed","title":"set_seed(seed=None)","text":"

(Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.

Parameters:

seed (Optional[int]): The seed to set. Default: None.

Source code in src/nhssynth/common/common.py
def set_seed(seed: Optional[int] = None) -> None:\n    \"\"\"\n    (Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.\n\n    Args:\n        seed: The seed to set.\n    \"\"\"\n    if seed:\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        random.seed(seed)\n
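A minimal usage sketch (the seed value is arbitrary):

import numpy as np

from nhssynth.common.common import set_seed

set_seed(42)
first = np.random.rand(3)
set_seed(42)
second = np.random.rand(3)
assert (first == second).all()  # reseeding reproduces the same draws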
"},{"location":"reference/common/constants/","title":"constants","text":"

Define all of the common constants used throughout the project.

"},{"location":"reference/common/debugging/","title":"debugging","text":"

Debugging utilities.

"},{"location":"reference/common/dicts/","title":"dicts","text":"

Common functions for working with dictionaries.

"},{"location":"reference/common/dicts/#nhssynth.common.dicts.filter_dict","title":"filter_dict(d, filter_keys, include=False)","text":"

Given a dictionary, return a new dictionary either including or excluding keys in a given filter set.

Parameters:

d (dict): A dictionary to filter. Required.
filter_keys (Union[set, list]): A list or set of keys to either include or exclude. Required.
include (bool): Whether to return a dictionary including or excluding the keys in filter_keys. Default: False.

Returns:

dict: A filtered dictionary.

Examples:

>>> d = {'a': 1, 'b': 2, 'c': 3}\n>>> filter_dict(d, {'a', 'b'})\n{'c': 3}\n>>> filter_dict(d, {'a', 'b'}, include=True)\n{'a': 1, 'b': 2}\n
Source code in src/nhssynth/common/dicts.py
def filter_dict(d: dict, filter_keys: Union[set, list], include: bool = False) -> dict:\n    \"\"\"\n    Given a dictionary, return a new dictionary either including or excluding keys in a given `filter` set.\n\n    Args:\n        d: A dictionary to filter.\n        filter_keys: A list or set of keys to either include or exclude.\n        include: Determine whether to return a dictionary including or excluding keys in `filter`.\n\n    Returns:\n        A filtered dictionary.\n\n    Examples:\n        >>> d = {'a': 1, 'b': 2, 'c': 3}\n        >>> filter_dict(d, {'a', 'b'})\n        {'c': 3}\n        >>> filter_dict(d, {'a', 'b'}, include=True)\n        {'a': 1, 'b': 2}\n    \"\"\"\n    if include:\n        filtered_keys = set(filter_keys) & set(d.keys())\n    else:\n        filtered_keys = set(d.keys()) - set(filter_keys)\n    return {k: v for k, v in d.items() if k in filtered_keys}\n
"},{"location":"reference/common/dicts/#nhssynth.common.dicts.flatten_dict","title":"flatten_dict(d)","text":"

Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.

Parameters:

d (dict[str, Any]): A dictionary with potentially nested keys. Required.

Returns:

dict[str, Any]: A flattened dictionary.

Raises:

ValueError: If duplicate keys are found in the flattened dictionary.

Examples:

>>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}\n>>> flatten_dict(d)\n{'a': 1, 'c': 2, 'e': 3}\n
Source code in src/nhssynth/common/dicts.py
def flatten_dict(d: dict[str, Any]) -> dict[str, Any]:\n    \"\"\"\n    Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.\n\n    Args:\n        d: A dictionary with potentially nested keys.\n\n    Returns:\n        A flattened dictionary.\n\n    Raises:\n        ValueError: If duplicate keys are found in the flattened dictionary.\n\n    Examples:\n        >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}\n        >>> flatten_dict(d)\n        {'a': 1, 'c': 2, 'e': 3}\n    \"\"\"\n    items = []\n    for k, v in d.items():\n        if isinstance(v, dict):\n            items.extend(flatten_dict(v).items())\n        else:\n            items.append((k, v))\n    if len(set([p[0] for p in items])) != len(items):\n        raise ValueError(\"Duplicate keys found in flattened dictionary\")\n    return dict(items)\n
"},{"location":"reference/common/dicts/#nhssynth.common.dicts.get_key_by_value","title":"get_key_by_value(d, value)","text":"

Find the first key in a dictionary with a given value.

Parameters:

d (dict): A dictionary to search through. Required.
value (Any): The value to search for. Required.

Returns:

Union[Any, None]: The first key in d with the given value, or None if no such key exists.

Examples:

>>> d = {'a': 1, 'b': 2, 'c': 1}\n>>> get_key_by_value(d, 2)\n'b'\n>>> get_key_by_value(d, 3)\nNone\n
Source code in src/nhssynth/common/dicts.py
def get_key_by_value(d: dict, value: Any) -> Union[Any, None]:\n    \"\"\"\n    Find the first key in a dictionary with a given value.\n\n    Args:\n        d: A dictionary to search through.\n        value: The value to search for.\n\n    Returns:\n        The first key in `d` with the value `value`, or `None` if no such key exists.\n\n    Examples:\n        >>> d = {'a': 1, 'b': 2, 'c': 1}\n        >>> get_key_by_value(d, 2)\n        'b'\n        >>> get_key_by_value(d, 3)\n        None\n\n    \"\"\"\n    for key, val in d.items():\n        if val == value:\n            return key\n    return None\n
"},{"location":"reference/common/io/","title":"io","text":"

Common building-block functions for handling module input and output.

"},{"location":"reference/common/io/#nhssynth.common.io.check_exists","title":"check_exists(fns, dir)","text":"

Checks if the files in fns exist in dir.

Parameters:

fns (list[str]): The list of files to check. Required.
dir (Path): The directory the files should exist in. Required.

Raises:

FileNotFoundError: If any of the files in fns do not exist in dir.

Source code in src/nhssynth/common/io.py
def check_exists(fns: list[str], dir: Path) -> None:\n    \"\"\"\n    Checks if the files in `fns` exist in `dir`.\n\n    Args:\n        fns: The list of files to check.\n        dir: The directory the files should exist in.\n\n    Raises:\n        FileNotFoundError: If any of the files in `fns` do not exist in `dir`.\n    \"\"\"\n    for fn in fns:\n        if not (dir / fn).exists():\n            raise FileNotFoundError(f\"File {fn} does not exist at {dir}.\")\n
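A brief usage sketch (the directory and filenames are invented for illustration):

from pathlib import Path

from nhssynth.common.io import check_exists

try:
    check_exists(["typed.pkl", "metadata.yaml"], Path("experiments/example"))
except FileNotFoundError as e:
    print(e)  # e.g. File typed.pkl does not exist at experiments/example.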
"},{"location":"reference/common/io/#nhssynth.common.io.consistent_ending","title":"consistent_ending(fn, ending='.pkl', suffix='')","text":"

Ensures that the filename fn ends with ending. If not, removes any existing ending and appends ending.

Parameters:

fn (str): The filename to check. Required.
ending (str): The desired ending to check for. Default: ".pkl".
suffix (str): A suffix to append to the filename before the ending. Default: "".

Returns:

str: The filename with the correct ending and potentially an inserted suffix.

Source code in src/nhssynth/common/io.py
def consistent_ending(fn: str, ending: str = \".pkl\", suffix: str = \"\") -> str:\n    \"\"\"\n    Ensures that the filename `fn` ends with `ending`. If not, removes any existing ending and appends `ending`.\n\n    Args:\n        fn: The filename to check.\n        ending: The desired ending to check for. Default is \".pkl\".\n        suffix: A suffix to append to the filename before the ending.\n\n    Returns:\n        The filename with the correct ending and potentially an inserted suffix.\n    \"\"\"\n    path_fn = Path(fn)\n    return str(path_fn.parent / path_fn.stem) + (\"_\" if suffix else \"\") + suffix + ending\n
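For instance (filenames invented for illustration):

from nhssynth.common.io import consistent_ending

print(consistent_ending("dataset.csv"))                  # dataset.pkl
print(consistent_ending("dataset", ending=".csv"))       # dataset.csv
print(consistent_ending("dataset.pkl", suffix="typed"))  # dataset_typed.pkl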
"},{"location":"reference/common/io/#nhssynth.common.io.consistent_endings","title":"consistent_endings(args)","text":"

Wrapper around consistent_ending to apply it to a list of filenames.

Parameters:

args (list[Union[str, tuple[str, str], tuple[str, str, str]]]): The list of filenames to check. Each element can be a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix. Required.

Returns:

list[str]: The list of filenames with the correct endings.

Source code in src/nhssynth/common/io.py
def consistent_endings(args: list[Union[str, tuple[str, str], tuple[str, str, str]]]) -> list[str]:\n    \"\"\"\n    Wrapper around `consistent_ending` to apply it to a list of filenames.\n\n    Args:\n        args: The list of filenames to check. Can take the form of a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix.\n\n    Returns:\n        The list of filenames with the correct endings.\n    \"\"\"\n    return list(consistent_ending(arg) if isinstance(arg, str) else consistent_ending(*arg) for arg in args)\n
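The three accepted element forms can be mixed freely, e.g. (filenames invented for illustration):

from nhssynth.common.io import consistent_endings

print(consistent_endings(["typed", ("input", ".csv"), ("dataset", ".pkl", "transformed")]))
# ['typed.pkl', 'input.csv', 'dataset_transformed.pkl']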
"},{"location":"reference/common/io/#nhssynth.common.io.experiment_io","title":"experiment_io(experiment_name, dir_experiments='experiments')","text":"

Create an experiment's directory and return the path.

Parameters:

experiment_name (str): The name of the experiment. Required.
dir_experiments (str): The name of the directory containing all experiments. Default: "experiments".

Returns:

str: The path to the experiment directory.

Source code in src/nhssynth/common/io.py
def experiment_io(experiment_name: str, dir_experiments: str = \"experiments\") -> str:\n    \"\"\"\n    Create an experiment's directory and return the path.\n\n    Args:\n        experiment_name: The name of the experiment.\n        dir_experiments: The name of the directory containing all experiments.\n\n    Returns:\n        The path to the experiment directory.\n    \"\"\"\n    dir_experiment = Path(dir_experiments) / experiment_name\n    dir_experiment.mkdir(parents=True, exist_ok=True)\n    return dir_experiment\n
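A minimal usage sketch (the experiment name is invented for illustration):

from nhssynth.common.io import experiment_io

dir_experiment = experiment_io("demo-run")  # creates experiments/demo-run if it does not already exist
print(dir_experiment)  # experiments/demo-run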
"},{"location":"reference/common/io/#nhssynth.common.io.potential_suffix","title":"potential_suffix(fn, fn_base)","text":"

Checks if fn is a suffix (starts with an underscore) to append to fn_base, or a filename in its own right.

Parameters:

fn (str): The filename / potential suffix to append to fn_base. Required.
fn_base (str): The name of the file the suffix would attach to. Required.

Returns:

str: The appropriately processed fn.

Source code in src/nhssynth/common/io.py
def potential_suffix(fn: str, fn_base: str) -> str:\n    \"\"\"\n    Checks if `fn` is a suffix (starts with an underscore) to append to `fn_base`, or a filename in its own right.\n\n    Args:\n        fn: The filename / potential suffix to append to `fn_base`.\n        fn_base: The name of the file the suffix would attach to.\n\n    Returns:\n        The appropriately processed `fn`\n    \"\"\"\n    fn_base = Path(fn_base).stem\n    if fn[0] == \"_\":\n        return fn_base + fn\n    else:\n        return fn\n
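For instance (filenames invented for illustration):

from nhssynth.common.io import potential_suffix

print(potential_suffix("_typed.pkl", "dataset.csv"))  # dataset_typed.pkl (treated as a suffix)
print(potential_suffix("other.pkl", "dataset.csv"))   # other.pkl (treated as a filename in its own right)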
"},{"location":"reference/common/io/#nhssynth.common.io.potential_suffixes","title":"potential_suffixes(fns, fn_base)","text":"

Wrapper around potential_suffix to apply it to a list of filenames.

Parameters:

fns (list[str]): The list of filenames / potential suffixes to append to fn_base. Required.
fn_base (str): The name of the file the suffixes would attach to. Required.

Source code in src/nhssynth/common/io.py
def potential_suffixes(fns: list[str], fn_base: str) -> list[str]:\n    \"\"\"\n    Wrapper around `potential_suffix` to apply it to a list of filenames.\n\n    Args:\n        fns: The list of filenames / potential suffixes to append to `fn_base`.\n        fn_base: The name of the file the suffixes would attach to.\n    \"\"\"\n    return list(potential_suffix(fn, fn_base) for fn in fns)\n
"},{"location":"reference/common/io/#nhssynth.common.io.warn_if_path_supplied","title":"warn_if_path_supplied(fns, dir)","text":"

Warns if the files in fns include directory separators.

Parameters:

fns (list[str]): The list of files to check. Required.
dir (Path): The directory the files should exist in. Required.

Warns:

UserWarning: When the path to any of the files in fns includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.

Source code in src/nhssynth/common/io.py
def warn_if_path_supplied(fns: list[str], dir: Path) -> None:\n    \"\"\"\n    Warns if the files in `fns` include directory separators.\n\n    Args:\n        fns: The list of files to check.\n        dir: The directory the files should exist in.\n\n    Warnings:\n        UserWarning: when the path to any of the files in `fns` includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.\n    \"\"\"\n    for fn in fns:\n        if \"/\" in fn:\n            warnings.warn(\n                f\"Using the path supplied appended to {dir}, i.e. attempting to read data from {dir / fn}\",\n                UserWarning,\n            )\n
"},{"location":"reference/common/strings/","title":"strings","text":"

String manipulation functions.

"},{"location":"reference/common/strings/#nhssynth.common.strings.add_spaces_before_caps","title":"add_spaces_before_caps(string)","text":"

Adds spaces before capital letters in a string if there is a lower-case letter following it.

Parameters:

string (str): The string to add spaces to. Required.

Returns:

str: The string with spaces added before capital letters.

Examples:

>>> add_spaces_before_caps(\"HelloWorld\")\n'Hello World'\n>>> add_spaces_before_caps(\"HelloWorldAGAIN\")\n'Hello World AGAIN'\n
Source code in src/nhssynth/common/strings.py
def add_spaces_before_caps(string: str) -> str:\n    \"\"\"\n    Adds spaces before capital letters in a string if there is a lower-case letter following it.\n\n    Args:\n        string: The string to add spaces to.\n\n    Returns:\n        The string with spaces added before capital letters.\n\n    Examples:\n        >>> add_spaces_before_caps(\"HelloWorld\")\n        'Hello World'\n        >>> add_spaces_before_caps(\"HelloWorldAGAIN\")\n        'Hello World AGAIN'\n    \"\"\"\n    return \" \".join(re.findall(r\"[a-z]?[A-Z][a-z]+|[A-Z]+(?=[A-Z][a-z]|\\b)\", string))\n
"},{"location":"reference/common/strings/#nhssynth.common.strings.format_timedelta","title":"format_timedelta(start, finish)","text":"

Calculate and prettily format the difference between two calls to time.time().

Parameters:

start (float): The start time. Required.
finish (float): The finish time. Required.

Returns:

str: A string containing the time difference in a human-readable format.

Source code in src/nhssynth/common/strings.py
def format_timedelta(start: float, finish: float) -> str:\n    \"\"\"\n    Calculate and prettily format the difference between two calls to `time.time()`.\n\n    Args:\n        start: The start time.\n        finish: The finish time.\n\n    Returns:\n        A string containing the time difference in a human-readable format.\n    \"\"\"\n    total = datetime.timedelta(seconds=finish - start)\n    hours, remainder = divmod(total.seconds, 3600)\n    minutes, seconds = divmod(remainder, 60)\n\n    if total.days > 0:\n        delta_str = f\"{total.days}d {hours}h {minutes}m {seconds}s\"\n    elif hours > 0:\n        delta_str = f\"{hours}h {minutes}m {seconds}s\"\n    elif minutes > 0:\n        delta_str = f\"{minutes}m {seconds}s\"\n    else:\n        delta_str = f\"{seconds}s\"\n    return delta_str\n
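A minimal usage sketch:

import time

from nhssynth.common.strings import format_timedelta

start = time.time()
time.sleep(1.5)  # stand-in for some real work
print(format_timedelta(start, time.time()))  # e.g. 1s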
"},{"location":"reference/modules/","title":"modules","text":""},{"location":"reference/modules/dashboard/","title":"dashboard","text":""},{"location":"reference/modules/dashboard/Upload/","title":"Upload","text":""},{"location":"reference/modules/dashboard/Upload/#nhssynth.modules.dashboard.Upload.get_component","title":"get_component(args, name, component_type, text)","text":"

Generate an upload field and its functionality for a given component of the evaluations.

Parameters:

args (Namespace): The parsed command-line arguments, used to load the component from disk when a path is supplied. Required.
name (str): The name of the component as it should be recorded in the session state and as it exists in the args. Required.
component_type (Any): The type of the component (to ensure that only the expected object can be uploaded). Required.
text (str): The human-readable text to display to the user as part of the element. Required.

Source code in src/nhssynth/modules/dashboard/Upload.py
def get_component(args: argparse.Namespace, name: str, component_type: Any, text: str) -> None:\n    \"\"\"\n    Generate an upload field and its functionality for a given component of the evaluations.\n\n    Args:\n        name: The name of the component as it should be recorded in the session state and as it exists in the args.\n        component_type: The type of the component (to ensure that only the expected object can be uploaded)\n        text: The human-readable text to display to the user as part of the element.\n    \"\"\"\n    uploaded = st.file_uploader(f\"Upload a pickle file containing a {text}\", type=\"pkl\")\n    if getattr(args, name):\n        with open(os.getcwd() + \"/\" + getattr(args, name), \"rb\") as f:\n            loaded = pickle.load(f)\n    if uploaded is not None:\n        loaded = pickle.load(uploaded)\n    if loaded is not None:\n        assert isinstance(loaded, component_type), f\"Uploaded file does not contain a {text}!\"\n        st.session_state[name] = loaded.contents\n        st.success(f\"Loaded {text}!\")\n
"},{"location":"reference/modules/dashboard/Upload/#nhssynth.modules.dashboard.Upload.parse_args","title":"parse_args()","text":"

These arguments allow a user to automatically load the required data for the dashboard from disk.

Returns:

Namespace: The parsed arguments.

Source code in src/nhssynth/modules/dashboard/Upload.py
def parse_args() -> argparse.Namespace:\n    \"\"\"\n    These arguments allow a user to automatically load the required data for the dashboard from disk.\n\n    Returns:\n        The parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"NHSSynth Evaluation Dashboard\")\n    parser.add_argument(\"--evaluations\", type=str, help=\"Path to a set of evaluations.\")\n    parser.add_argument(\"--experiments\", type=str, help=\"Path to a set of experiments.\")\n    parser.add_argument(\"--synthetic-datasets\", type=str, help=\"Path to a set of synthetic datasets.\")\n    parser.add_argument(\"--typed\", type=str, help=\"Path to a typed real dataset.\")\n    return parser.parse_args()\n
"},{"location":"reference/modules/dashboard/io/","title":"io","text":""},{"location":"reference/modules/dashboard/io/#nhssynth.modules.dashboard.io.check_input_paths","title":"check_input_paths(dir_experiment, fn_dataset, fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations)","text":"

Sets up the input paths for the files the dashboard requires.

Parameters:

dir_experiment (str): The path to the experiment directory. Required.
fn_dataset (str): The base name of the dataset. Required.
fn_typed (str): The filename of the typed dataset. Required.
fn_experiments (str): The filename of the collection of experiments. Required.
fn_synthetic_datasets (str): The filename of the collection of synthetic datasets. Required.
fn_evaluations (str): The filename of the collection of evaluations. Required.

Returns:

The paths to the typed dataset, experiments, synthetic datasets and evaluations files.

Source code in src/nhssynth/modules/dashboard/io.py
def check_input_paths(\n    dir_experiment: str,\n    fn_dataset: str,\n    fn_typed: str,\n    fn_experiments: str,\n    fn_synthetic_datasets: str,\n    fn_evaluations: str,\n) -> str:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        dir_experiment: The path to the experiment directory.\n        fn_dataset: The base name of the dataset.\n        fn_experiments: The filename of the collection of experiments.\n        fn_synthetic_datasets: The filename of the collection of synthetic datasets.\n        fn_evaluations: The filename of the collection of evaluations.\n\n    Returns:\n        The paths\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations = io.consistent_endings(\n        [fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations]\n    )\n    fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations = io.potential_suffixes(\n        [fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], fn_dataset\n    )\n    io.warn_if_path_supplied([fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], dir_experiment)\n    io.check_exists([fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], dir_experiment)\n    return (\n        dir_experiment / fn_typed,\n        dir_experiment / fn_experiments,\n        dir_experiment / fn_synthetic_datasets,\n        dir_experiment / fn_evaluations,\n    )\n
"},{"location":"reference/modules/dashboard/run/","title":"run","text":""},{"location":"reference/modules/dashboard/utils/","title":"utils","text":""},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.hide_streamlit_content","title":"hide_streamlit_content()","text":"

Hide the footer message and deploy button in Streamlit.

Source code in src/nhssynth/modules/dashboard/utils.py
def hide_streamlit_content() -> None:\n    \"\"\"\n    Hide the footer message and deploy button in Streamlit.\n    \"\"\"\n    hide_streamlit_style = \"\"\"\n    <style>\n    footer {visibility: hidden;}\n    .stDeployButton {visibility: hidden;}\n    </style>\n    \"\"\"\n    st.markdown(hide_streamlit_style, unsafe_allow_html=True)\n
"},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.id_selector","title":"id_selector(df)","text":"

Select an ID from the dataframe to then operate on.

Parameters:

df (DataFrame): The dataframe to select an ID from. Required.

Returns:

Series: The dataset subset to only the row corresponding to the ID.

Source code in src/nhssynth/modules/dashboard/utils.py
def id_selector(df: pd.DataFrame) -> pd.Series:\n    \"\"\"\n    Select an ID from the dataframe to then operate on.\n\n    Args:\n        df: The dataframe to select an ID from.\n\n    Returns:\n        The dataset subset to only the row corresponding to the ID.\n    \"\"\"\n    architecture = st.sidebar.selectbox(\n        \"Select architecture to display\", df.index.get_level_values(\"architecture\").unique()\n    )\n    # Different architectures may have different numbers of repeats and configs\n    repeats = df.loc[architecture].index.get_level_values(\"repeat\").astype(int).unique()\n    configs = df.loc[architecture].index.get_level_values(\"config\").astype(int).unique()\n    if len(repeats) > 1:\n        repeat = st.sidebar.selectbox(\"Select repeat to display\", repeats)\n    else:\n        repeat = repeats[0]\n    if len(configs) > 1:\n        config = st.sidebar.selectbox(\"Select configuration to display\", configs)\n    else:\n        config = configs[0]\n    return df.loc[(architecture, repeat, config)]\n
"},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.subset_selector","title":"subset_selector(df)","text":"

Select a subset of the dataframe to then operate on.

Parameters:

df (DataFrame): The dataframe to select a subset of. Required.

Returns:

DataFrame: The subset of the dataframe.

Source code in src/nhssynth/modules/dashboard/utils.py
def subset_selector(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Select a subset of the dataframe to then operate on.\n\n    Args:\n        df: The dataframe to select a subset of.\n\n    Returns:\n        The subset of the dataframe.\n    \"\"\"\n    architectures = df.index.get_level_values(\"architecture\").unique().tolist()\n    repeats = df.index.get_level_values(\"repeat\").astype(int).unique().tolist()\n    configs = df.index.get_level_values(\"config\").astype(int).unique().tolist()\n    selected_architectures = st.sidebar.multiselect(\n        \"Select architectures to display\", architectures, default=architectures\n    )\n    selected_repeats = st.sidebar.multiselect(\"Select repeats to display\", repeats, default=repeats[0])\n    selected_configs = st.sidebar.multiselect(\"Select configurations to display\", configs, default=configs)\n    return df.loc[(selected_architectures, selected_repeats, selected_configs)]\n
"},{"location":"reference/modules/dashboard/pages/","title":"pages","text":""},{"location":"reference/modules/dashboard/pages/1_Tables/","title":"1_Tables","text":""},{"location":"reference/modules/dashboard/pages/2_Plots/","title":"2_Plots","text":""},{"location":"reference/modules/dashboard/pages/2_Plots/#nhssynth.modules.dashboard.pages.2_Plots.prepare_for_dimensionality","title":"prepare_for_dimensionality(df)","text":"

Factorize all categorical columns in a dataframe, convert datetimes to numeric, and min-max scale every column to the range [0, 1].

Source code in src/nhssynth/modules/dashboard/pages/2_Plots.py
def prepare_for_dimensionality(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Factorize all categorical columns in a dataframe.\"\"\"\n    for col in df.columns:\n        if df[col].dtype == \"object\":\n            df[col] = pd.factorize(df[col])[0]\n        elif df[col].dtype == \"datetime64[ns]\":\n            df[col] = pd.to_numeric(df[col])\n        min_val = df[col].min()\n        max_val = df[col].max()\n        df[col] = (df[col] - min_val) / (max_val - min_val)\n    return df\n
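A small sketch of the effect on a toy dataframe (the data is invented for illustration; the loop simply reproduces the body of prepare_for_dimensionality so the snippet stands alone):

import pandas as pd

df = pd.DataFrame({"sex": ["F", "M", "F", "F"], "age": [20, 40, 60, 80]})

for col in df.columns:
    if df[col].dtype == "object":
        df[col] = pd.factorize(df[col])[0]  # encode categories as integers
    elif df[col].dtype == "datetime64[ns]":
        df[col] = pd.to_numeric(df[col])    # convert datetimes to numeric
    df[col] = (df[col] - df[col].min()) / (df[col].max() - df[col].min())  # min-max scale

print(df)
#    sex       age
# 0  0.0  0.000000
# 1  1.0  0.333333
# 2  0.0  0.666667
# 3  0.0  1.000000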
"},{"location":"reference/modules/dashboard/pages/3_Experiment_Configurations/","title":"3_Experiment_Configurations","text":""},{"location":"reference/modules/dataloader/","title":"dataloader","text":""},{"location":"reference/modules/dataloader/constraints/","title":"constraints","text":""},{"location":"reference/modules/dataloader/io/","title":"io","text":""},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.check_input_paths","title":"check_input_paths(fn_input, fn_metadata, dir_data)","text":"

Formats the input filenames and directory for an experiment.

Parameters:

fn_input (str): The input data filename. Required.
fn_metadata (str): The metadata filename / suffix to append to fn_input. Required.
dir_data (str): The directory that should contain both of the above. Required.

Returns:

tuple[Path, str, str]: A tuple containing the correct directory path, input data filename and metadata filename (used for both in and out).

Warns:

UserWarning: When the path to fn_input includes directory separators, as this is not supported and may not work as intended.
UserWarning: When the path to fn_metadata includes directory separators, as this is not supported and may not work as intended.

Source code in src/nhssynth/modules/dataloader/io.py
def check_input_paths(\n    fn_input: str,\n    fn_metadata: str,\n    dir_data: str,\n) -> tuple[Path, str, str]:\n    \"\"\"\n    Formats the input filenames and directory for an experiment.\n\n    Args:\n        fn_input: The input data filename.\n        fn_metadata: The metadata filename / suffix to append to `fn_input`.\n        dir_data: The directory that should contain both of the above.\n\n    Returns:\n        A tuple containing the correct directory path, input data filename and metadata filename (used for both in and out).\n\n    Warnings:\n        UserWarning: When the path to `fn_input` includes directory separators, as this is not supported and may not work as intended.\n        UserWarning: When the path to `fn_metadata` includes directory separators, as this is not supported and may not work as intended.\n    \"\"\"\n    fn_input, fn_metadata = io.consistent_endings([(fn_input, \".csv\"), (fn_metadata, \".yaml\")])\n    dir_data = Path(dir_data)\n    fn_metadata = io.potential_suffix(fn_metadata, fn_input)\n    io.warn_if_path_supplied([fn_input, fn_metadata], dir_data)\n    io.check_exists([fn_input], dir_data)\n    return dir_data, fn_input, fn_metadata\n
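A usage sketch, assuming a file data/support.csv exists on disk (the dataset name and directory are invented for illustration):

from nhssynth.modules.dataloader.io import check_input_paths

dir_data, fn_input, fn_metadata = check_input_paths("support", "_metadata", "data")
print(dir_data, fn_input, fn_metadata)  # data support.csv support_metadata.yaml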
"},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.check_output_paths","title":"check_output_paths(fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata, dir_experiment)","text":"

Formats the output filenames for an experiment.

Parameters:

fn_dataset (str): The input data filename. Required.
fn_typed (str): The typed input data filename/suffix to append to fn_dataset. Required.
fn_transformed (str): The transformed output data filename/suffix to append to fn_dataset. Required.
fn_metatransformer (str): The metatransformer filename/suffix to append to fn_dataset. Required.
fn_constraint_graph (str): The constraint graph filename/suffix to append to fn_dataset. Required.
fn_sdv_metadata (str): The SDV metadata filename/suffix to append to fn_dataset. Required.
dir_experiment (Path): The experiment directory to write the outputs to. Required.

Returns:

tuple[str, str, str]: A tuple containing the formatted output filenames.

Warns:

UserWarning: When any of the filenames include directory separators, as this is not supported and may not work as intended.

Source code in src/nhssynth/modules/dataloader/io.py
def check_output_paths(\n    fn_dataset: str,\n    fn_typed: str,\n    fn_transformed: str,\n    fn_metatransformer: str,\n    fn_constraint_graph: str,\n    fn_sdv_metadata: str,\n    dir_experiment: Path,\n) -> tuple[str, str, str]:\n    \"\"\"\n    Formats the output filenames for an experiment.\n\n    Args:\n        fn_dataset: The input data filename.\n        fn_typed: The typed input data filename/suffix to append to `fn_dataset`.\n        fn_transformed: The transformed output data filename/suffix to append to `fn_dataset`.\n        fn_metatransformer: The metatransformer filename/suffix to append to `fn_dataset`.\n        fn_constraint_graph: The constraint graph filename/suffix to append to `fn_dataset`.\n        fn_sdv_metadata: The SDV metadata filename/suffix to append to `fn_dataset`.\n        dir_experiment: The experiment directory to write the outputs to.\n\n    Returns:\n        A tuple containing the formatted output filenames.\n\n    Warnings:\n        UserWarning: When any of the filenames include directory separators, as this is not supported and may not work as intended.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = io.consistent_endings(\n        [fn_typed, fn_transformed, fn_metatransformer, (fn_constraint_graph, \".html\"), fn_sdv_metadata]\n    )\n    fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = io.potential_suffixes(\n        [fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata], fn_dataset\n    )\n    io.warn_if_path_supplied(\n        [fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata], dir_experiment\n    )\n    return fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata\n
"},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.write_data_outputs","title":"write_data_outputs(metatransformer, fn_dataset, fn_metadata, dir_experiment, args)","text":"

Writes the transformed data and metatransformer to disk.

Parameters:

metatransformer (MetaTransformer): The metatransformer used to transform the data into its model-ready state. Required.
fn_dataset (str): The base dataset filename. Required.
fn_metadata (str): The metadata filename. Required.
dir_experiment (Path): The experiment directory to write the outputs to. Required.
args (Namespace): The full set of parsed command line arguments. Required.

Returns:

The filename of the dataset used.

Source code in src/nhssynth/modules/dataloader/io.py
def write_data_outputs(\n    metatransformer: MetaTransformer,\n    fn_dataset: str,\n    fn_metadata: str,\n    dir_experiment: Path,\n    args: argparse.Namespace,\n) -> None:\n    \"\"\"\n    Writes the transformed data and metatransformer to disk.\n\n    Args:\n        metatransformer: The metatransformer used to transform the data into its model-ready state.\n        fn_dataset: The base dataset filename.\n        fn_metadata: The metadata filename.\n        dir_experiment: The experiment directory to write the outputs to.\n        args: The full set of parsed command line arguments.\n\n    Returns:\n        The filename of the dataset used.\n    \"\"\"\n    fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = check_output_paths(\n        fn_dataset,\n        args.typed,\n        args.transformed,\n        args.metatransformer,\n        args.constraint_graph,\n        args.sdv_metadata,\n        dir_experiment,\n    )\n    metatransformer.save_metadata(dir_experiment / fn_metadata, args.collapse_yaml)\n    metatransformer.save_constraint_graphs(dir_experiment / fn_constraint_graph)\n    with open(dir_experiment / fn_typed, \"wb\") as f:\n        pickle.dump(TypedDataset(metatransformer.get_typed_dataset()), f)\n    transformed_dataset = metatransformer.get_transformed_dataset()\n    transformed_dataset.to_pickle(dir_experiment / fn_transformed)\n    if args.write_csv:\n        chunks = np.array_split(transformed_dataset.index, 100)\n        for chunk, subset in enumerate(tqdm(chunks, desc=\"Writing transformed dataset to CSV\", unit=\"chunk\")):\n            if chunk == 0:\n                transformed_dataset.loc[subset].to_csv(\n                    dir_experiment / (fn_transformed[:-3] + \"csv\"), mode=\"w\", index=False\n                )\n            else:\n                transformed_dataset.loc[subset].to_csv(\n                    dir_experiment / (fn_transformed[:-3] + \"csv\"), mode=\"a\", index=False, header=False\n                )\n    with open(dir_experiment / fn_metatransformer, \"wb\") as f:\n        pickle.dump(metatransformer, f)\n    with open(dir_experiment / fn_sdv_metadata, \"wb\") as f:\n        pickle.dump(metatransformer.get_sdv_metadata(), f)\n\n    return fn_dataset\n
"},{"location":"reference/modules/dataloader/metadata/","title":"metadata","text":""},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData","title":"MetaData","text":"Source code in src/nhssynth/modules/dataloader/metadata.py
class MetaData:\n    class ColumnMetaData:\n        def __init__(self, name: str, data: pd.Series, raw: dict) -> None:\n            self.name = name\n            self.dtype: np.dtype = self._validate_dtype(data, raw.get(\"dtype\"))\n            self.categorical: bool = self._validate_categorical(data, raw.get(\"categorical\"))\n            self.missingness_strategy: GenericMissingnessStrategy = self._validate_missingness_strategy(\n                raw.get(\"missingness\")\n            )\n            self.transformer: ColumnTransformer = self._validate_transformer(raw.get(\"transformer\"))\n\n        def _validate_dtype(self, data: pd.Series, dtype_raw: Optional[Union[dict, str]] = None) -> np.dtype:\n            if isinstance(dtype_raw, dict):\n                dtype_name = dtype_raw.pop(\"name\", None)\n            elif isinstance(dtype_raw, str):\n                dtype_name = dtype_raw\n            else:\n                dtype_name = self._infer_dtype(data)\n            try:\n                dtype = np.dtype(dtype_name)\n            except TypeError:\n                warnings.warn(\n                    f\"Invalid dtype specification '{dtype_name}' for column '{self.name}', ignoring dtype for this column\"\n                )\n                dtype = self._infer_dtype(data)\n            if dtype.kind == \"M\":\n                self._setup_datetime_config(data, dtype_raw)\n            elif dtype.kind in [\"f\", \"i\", \"u\"]:\n                self.rounding_scheme = self._validate_rounding_scheme(data, dtype, dtype_raw)\n            return dtype\n\n        def _infer_dtype(self, data: pd.Series) -> np.dtype:\n            return data.dtype.name\n\n        def _infer_datetime_format(self, data: pd.Series) -> str:\n            return _guess_datetime_format_for_array(data[data.notna()].astype(str).to_numpy())\n\n        def _setup_datetime_config(self, data: pd.Series, datetime_config: dict) -> dict:\n            \"\"\"\n            Add keys to `datetime_config` corresponding to args from the `pd.to_datetime` function\n            (see [the docs](https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html))\n            \"\"\"\n            if not isinstance(datetime_config, dict):\n                datetime_config = {}\n            else:\n                datetime_config = filter_dict(datetime_config, {\"format\", \"floor\"}, include=True)\n            if \"format\" not in datetime_config:\n                datetime_config[\"format\"] = self._infer_datetime_format(data)\n            self.datetime_config = datetime_config\n\n        def _validate_rounding_scheme(self, data: pd.Series, dtype: np.dtype, dtype_dict: dict) -> int:\n            if dtype_dict and \"rounding_scheme\" in dtype_dict:\n                return dtype_dict[\"rounding_scheme\"]\n            else:\n                if dtype.kind != \"f\":\n                    return 1.0\n                roundable_data = data[data.notna()]\n                for i in range(np.finfo(dtype).precision):\n                    if (roundable_data.round(i) == roundable_data).all():\n                        return 10**-i\n            return None\n\n        def _validate_categorical(self, data: pd.Series, categorical: Optional[bool] = None) -> bool:\n            if categorical is None:\n                return self._infer_categorical(data)\n            elif not isinstance(categorical, bool):\n                warnings.warn(\n                    f\"Invalid categorical '{categorical}' for column '{self.name}', ignoring categorical for this column\"\n         
       )\n                return self._infer_categorical(data)\n            else:\n                self.boolean = data.nunique() <= 2\n                return categorical\n\n        def _infer_categorical(self, data: pd.Series) -> bool:\n            self.boolean = data.nunique() <= 2\n            return data.nunique() <= 10 or self.dtype.kind == \"O\"\n\n        def _validate_missingness_strategy(self, missingness_strategy: Optional[Union[dict, str]]) -> tuple[str, dict]:\n            if not missingness_strategy:\n                return None\n            if isinstance(missingness_strategy, dict):\n                impute = missingness_strategy.get(\"impute\", None)\n                strategy = \"impute\" if impute else missingness_strategy.get(\"strategy\", None)\n            else:\n                strategy = missingness_strategy\n            if (\n                strategy not in MISSINGNESS_STRATEGIES\n                or (strategy == \"impute\" and impute == \"mean\" and self.dtype.kind != \"f\")\n                or (strategy == \"impute\" and not impute)\n            ):\n                warnings.warn(\n                    f\"Invalid missingness strategy '{missingness_strategy}' for column '{self.name}', ignoring missingness strategy for this column\"\n                )\n                return None\n            return (\n                MISSINGNESS_STRATEGIES[strategy](impute) if strategy == \"impute\" else MISSINGNESS_STRATEGIES[strategy]()\n            )\n\n        def _validate_transformer(self, transformer: Optional[Union[dict, str]] = {}) -> tuple[str, dict]:\n            # if transformer is neither a dict nor a str statement below will raise a TypeError\n            if isinstance(transformer, dict):\n                self.transformer_name = transformer.get(\"name\")\n                self.transformer_config = filter_dict(transformer, \"name\")\n            elif isinstance(transformer, str):\n                self.transformer_name = transformer\n                self.transformer_config = {}\n            else:\n                if transformer is not None:\n                    warnings.warn(\n                        f\"Invalid transformer config '{transformer}' for column '{self.name}', ignoring transformer for this column\"\n                    )\n                self.transformer_name = None\n                self.transformer_config = {}\n            if not self.transformer_name:\n                return self._infer_transformer()\n            else:\n                try:\n                    return eval(self.transformer_name)(**self.transformer_config)\n                except NameError:\n                    warnings.warn(\n                        f\"Invalid transformer '{self.transformer_name}' or config '{self.transformer_config}' for column '{self.name}', ignoring transformer for this column\"\n                    )\n                    return self._infer_transformer()\n\n        def _infer_transformer(self) -> ColumnTransformer:\n            if self.categorical:\n                transformer = OHECategoricalTransformer(**self.transformer_config)\n            else:\n                transformer = ClusterContinuousTransformer(**self.transformer_config)\n            if self.dtype.kind == \"M\":\n                transformer = DatetimeTransformer(transformer)\n            return transformer\n\n    def __init__(self, data: pd.DataFrame, metadata: Optional[dict] = {}):\n        self.columns: pd.Index = data.columns\n        self.raw_metadata: dict = metadata\n        if 
set(self.raw_metadata[\"columns\"].keys()) - set(self.columns):\n            raise ValueError(\"Metadata contains keys that do not appear amongst the columns.\")\n        self.dropped_columns = [cn for cn in self.columns if self.raw_metadata[\"columns\"].get(cn, None) == \"drop\"]\n        self.columns = self.columns.drop(self.dropped_columns)\n        self._metadata = {\n            cn: self.ColumnMetaData(cn, data[cn], self.raw_metadata[\"columns\"].get(cn, {})) for cn in self.columns\n        }\n        self.constraints = ConstraintGraph(self.raw_metadata.get(\"constraints\", []), self.columns, self._metadata)\n\n    def __getitem__(self, key: str) -> dict[str, Any]:\n        return self._metadata[key]\n\n    def __iter__(self) -> Iterator:\n        return iter(self._metadata.values())\n\n    def __repr__(self) -> None:\n        return yaml.dump(self._metadata, default_flow_style=False, sort_keys=False)\n\n    @classmethod\n    def from_path(cls, data: pd.DataFrame, path_str: str):\n        \"\"\"\n        Instantiate a MetaData object from a YAML file via a specified path.\n\n        Args:\n            data: The data to be used to infer / validate the metadata.\n            path_str: The path to the metadata YAML file.\n\n        Returns:\n            The metadata object.\n        \"\"\"\n        path = pathlib.Path(path_str)\n        if path.exists():\n            with open(path) as stream:\n                metadata = yaml.safe_load(stream)\n            # Filter out the expanded alias/anchor group as it is not needed\n            metadata = filter_dict(metadata, {\"column_types\"})\n        else:\n            warnings.warn(f\"No metadata found at {path}...\")\n            metadata = {\"columns\": {}}\n        return cls(data, metadata)\n\n    def _collapse(self, metadata: dict) -> dict:\n        \"\"\"\n        Given a metadata dictionary, rewrite to collapse duplicate column types in order to leverage YAML anchors and shrink the file.\n\n        Args:\n            metadata: The metadata dictionary to be rewritten.\n\n        Returns:\n            A rewritten metadata dictionary with collapsed column types and transformers.\n                The returned dictionary has the following structure:\n                {\n                    \"column_types\": dict,\n                    **metadata  # one entry for each column in \"columns\" that now reference the dicts above\n                }\n                - \"column_types\" is a dictionary mapping column type indices to column type configurations.\n                - \"**metadata\" contains the original metadata dictionary, with column types rewritten to use the indices and \"column_types\".\n        \"\"\"\n        c_index = 1\n        column_types = {}\n        column_type_counts = {}\n        for cn, cd in metadata[\"columns\"].items():\n            if cd not in column_types.values():\n                column_types[c_index] = cd if isinstance(cd, str) else cd.copy()\n                column_type_counts[c_index] = 1\n                c_index += 1\n            else:\n                cix = get_key_by_value(column_types, cd)\n                column_type_counts[cix] += 1\n\n        for cn, cd in metadata[\"columns\"].items():\n            cix = get_key_by_value(column_types, cd)\n            if column_type_counts[cix] > 1:\n                metadata[\"columns\"][cn] = column_types[cix]\n            else:\n                column_types.pop(cix)\n\n        return {\"column_types\": {i + 1: x for i, x in enumerate(column_types.values())}, 
**metadata}\n\n    def _assemble(self, collapse_yaml: bool) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Rearrange the metadata into a dictionary that can be written to a YAML file.\n\n        Args:\n            collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n\n        Returns:\n            A dictionary containing the assembled metadata.\n        \"\"\"\n        assembled_metadata = {\n            \"columns\": {\n                cn: {\n                    \"dtype\": cmd.dtype.name\n                    if not hasattr(cmd, \"datetime_config\")\n                    else {\"name\": cmd.dtype.name, **cmd.datetime_config},\n                    \"categorical\": cmd.categorical,\n                }\n                for cn, cmd in self._metadata.items()\n            }\n        }\n        # We loop through the base dict above to add other parts if they are present in the metadata\n        for cn, cmd in self._metadata.items():\n            if cmd.missingness_strategy:\n                assembled_metadata[\"columns\"][cn][\"missingness\"] = (\n                    cmd.missingness_strategy.name\n                    if cmd.missingness_strategy.name != \"impute\"\n                    else {\"name\": cmd.missingness_strategy.name, \"impute\": cmd.missingness_strategy.impute}\n                )\n            if cmd.transformer_config:\n                assembled_metadata[\"columns\"][cn][\"transformer\"] = {\n                    **cmd.transformer_config,\n                    \"name\": cmd.transformer.__class__.__name__,\n                }\n\n        # Add back the dropped_columns not present in the metadata\n        if self.dropped_columns:\n            assembled_metadata[\"columns\"].update({cn: \"drop\" for cn in self.dropped_columns})\n\n        if collapse_yaml:\n            assembled_metadata = self._collapse(assembled_metadata)\n\n        # We add the constraints section after all of the formatting and processing above\n        # In general, the constraints are kept the same as the input (provided they passed validation)\n        # If `collapse_yaml` is specified, we output the minimum set of equivalent constraints\n        if self.constraints:\n            assembled_metadata[\"constraints\"] = (\n                [str(c) for c in self.constraints.minimal_constraints]\n                if collapse_yaml\n                else self.constraints.raw_constraint_strings\n            )\n        return assembled_metadata\n\n    def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:\n        \"\"\"\n        Writes metadata to a YAML file.\n\n        Args:\n            path: The path at which to write the metadata YAML file.\n            collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n        \"\"\"\n        with open(path, \"w\") as yaml_file:\n            yaml.safe_dump(\n                self._assemble(collapse_yaml),\n                yaml_file,\n                default_flow_style=False,\n                sort_keys=False,\n            )\n\n    def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:\n        \"\"\"\n        Map combinations of our metadata implementation to SDV's as required by SDMetrics.\n\n        Returns:\n            A dictionary containing the SDV metadata.\n        \"\"\"\n        sdv_metadata = {\n            \"columns\": {\n                cn: {\n                    \"sdtype\": \"boolean\"\n                    if 
cmd.boolean\n                    else \"categorical\"\n                    if cmd.categorical\n                    else \"datetime\"\n                    if cmd.dtype.kind == \"M\"\n                    else \"numerical\",\n                }\n                for cn, cmd in self._metadata.items()\n            }\n        }\n        for cn, cmd in self._metadata.items():\n            if cmd.dtype.kind == \"M\":\n                sdv_metadata[\"columns\"][cn][\"format\"] = cmd.datetime_config[\"format\"]\n        return sdv_metadata\n\n    def save_constraint_graphs(self, path: pathlib.Path) -> None:\n        \"\"\"\n        Output the constraint graphs as HTML files.\n\n        Args:\n            path: The path at which to write the constraint graph HTML files.\n        \"\"\"\n        self.constraints._output_graphs_html(path)\n
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.ColumnMetaData","title":"ColumnMetaData","text":"Source code in src/nhssynth/modules/dataloader/metadata.py
class ColumnMetaData:\n    def __init__(self, name: str, data: pd.Series, raw: dict) -> None:\n        self.name = name\n        self.dtype: np.dtype = self._validate_dtype(data, raw.get(\"dtype\"))\n        self.categorical: bool = self._validate_categorical(data, raw.get(\"categorical\"))\n        self.missingness_strategy: GenericMissingnessStrategy = self._validate_missingness_strategy(\n            raw.get(\"missingness\")\n        )\n        self.transformer: ColumnTransformer = self._validate_transformer(raw.get(\"transformer\"))\n\n    def _validate_dtype(self, data: pd.Series, dtype_raw: Optional[Union[dict, str]] = None) -> np.dtype:\n        if isinstance(dtype_raw, dict):\n            dtype_name = dtype_raw.pop(\"name\", None)\n        elif isinstance(dtype_raw, str):\n            dtype_name = dtype_raw\n        else:\n            dtype_name = self._infer_dtype(data)\n        try:\n            dtype = np.dtype(dtype_name)\n        except TypeError:\n            warnings.warn(\n                f\"Invalid dtype specification '{dtype_name}' for column '{self.name}', ignoring dtype for this column\"\n            )\n            dtype = self._infer_dtype(data)\n        if dtype.kind == \"M\":\n            self._setup_datetime_config(data, dtype_raw)\n        elif dtype.kind in [\"f\", \"i\", \"u\"]:\n            self.rounding_scheme = self._validate_rounding_scheme(data, dtype, dtype_raw)\n        return dtype\n\n    def _infer_dtype(self, data: pd.Series) -> np.dtype:\n        return data.dtype.name\n\n    def _infer_datetime_format(self, data: pd.Series) -> str:\n        return _guess_datetime_format_for_array(data[data.notna()].astype(str).to_numpy())\n\n    def _setup_datetime_config(self, data: pd.Series, datetime_config: dict) -> dict:\n        \"\"\"\n        Add keys to `datetime_config` corresponding to args from the `pd.to_datetime` function\n        (see [the docs](https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html))\n        \"\"\"\n        if not isinstance(datetime_config, dict):\n            datetime_config = {}\n        else:\n            datetime_config = filter_dict(datetime_config, {\"format\", \"floor\"}, include=True)\n        if \"format\" not in datetime_config:\n            datetime_config[\"format\"] = self._infer_datetime_format(data)\n        self.datetime_config = datetime_config\n\n    def _validate_rounding_scheme(self, data: pd.Series, dtype: np.dtype, dtype_dict: dict) -> int:\n        if dtype_dict and \"rounding_scheme\" in dtype_dict:\n            return dtype_dict[\"rounding_scheme\"]\n        else:\n            if dtype.kind != \"f\":\n                return 1.0\n            roundable_data = data[data.notna()]\n            for i in range(np.finfo(dtype).precision):\n                if (roundable_data.round(i) == roundable_data).all():\n                    return 10**-i\n        return None\n\n    def _validate_categorical(self, data: pd.Series, categorical: Optional[bool] = None) -> bool:\n        if categorical is None:\n            return self._infer_categorical(data)\n        elif not isinstance(categorical, bool):\n            warnings.warn(\n                f\"Invalid categorical '{categorical}' for column '{self.name}', ignoring categorical for this column\"\n            )\n            return self._infer_categorical(data)\n        else:\n            self.boolean = data.nunique() <= 2\n            return categorical\n\n    def _infer_categorical(self, data: pd.Series) -> bool:\n        self.boolean = data.nunique() <= 2\n    
    return data.nunique() <= 10 or self.dtype.kind == \"O\"\n\n    def _validate_missingness_strategy(self, missingness_strategy: Optional[Union[dict, str]]) -> tuple[str, dict]:\n        if not missingness_strategy:\n            return None\n        if isinstance(missingness_strategy, dict):\n            impute = missingness_strategy.get(\"impute\", None)\n            strategy = \"impute\" if impute else missingness_strategy.get(\"strategy\", None)\n        else:\n            strategy = missingness_strategy\n        if (\n            strategy not in MISSINGNESS_STRATEGIES\n            or (strategy == \"impute\" and impute == \"mean\" and self.dtype.kind != \"f\")\n            or (strategy == \"impute\" and not impute)\n        ):\n            warnings.warn(\n                f\"Invalid missingness strategy '{missingness_strategy}' for column '{self.name}', ignoring missingness strategy for this column\"\n            )\n            return None\n        return (\n            MISSINGNESS_STRATEGIES[strategy](impute) if strategy == \"impute\" else MISSINGNESS_STRATEGIES[strategy]()\n        )\n\n    def _validate_transformer(self, transformer: Optional[Union[dict, str]] = {}) -> tuple[str, dict]:\n        # if transformer is neither a dict nor a str statement below will raise a TypeError\n        if isinstance(transformer, dict):\n            self.transformer_name = transformer.get(\"name\")\n            self.transformer_config = filter_dict(transformer, \"name\")\n        elif isinstance(transformer, str):\n            self.transformer_name = transformer\n            self.transformer_config = {}\n        else:\n            if transformer is not None:\n                warnings.warn(\n                    f\"Invalid transformer config '{transformer}' for column '{self.name}', ignoring transformer for this column\"\n                )\n            self.transformer_name = None\n            self.transformer_config = {}\n        if not self.transformer_name:\n            return self._infer_transformer()\n        else:\n            try:\n                return eval(self.transformer_name)(**self.transformer_config)\n            except NameError:\n                warnings.warn(\n                    f\"Invalid transformer '{self.transformer_name}' or config '{self.transformer_config}' for column '{self.name}', ignoring transformer for this column\"\n                )\n                return self._infer_transformer()\n\n    def _infer_transformer(self) -> ColumnTransformer:\n        if self.categorical:\n            transformer = OHECategoricalTransformer(**self.transformer_config)\n        else:\n            transformer = ClusterContinuousTransformer(**self.transformer_config)\n        if self.dtype.kind == \"M\":\n            transformer = DatetimeTransformer(transformer)\n        return transformer\n
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.from_path","title":"from_path(data, path_str) classmethod","text":"

Instantiate a MetaData object from a YAML file via a specified path.

Parameters:

  • data (DataFrame): The data to be used to infer / validate the metadata. Required.
  • path_str (str): The path to the metadata YAML file. Required.

Returns:

  • The metadata object.

Source code in src/nhssynth/modules/dataloader/metadata.py
@classmethod\ndef from_path(cls, data: pd.DataFrame, path_str: str):\n    \"\"\"\n    Instantiate a MetaData object from a YAML file via a specified path.\n\n    Args:\n        data: The data to be used to infer / validate the metadata.\n        path_str: The path to the metadata YAML file.\n\n    Returns:\n        The metadata object.\n    \"\"\"\n    path = pathlib.Path(path_str)\n    if path.exists():\n        with open(path) as stream:\n            metadata = yaml.safe_load(stream)\n        # Filter out the expanded alias/anchor group as it is not needed\n        metadata = filter_dict(metadata, {\"column_types\"})\n    else:\n        warnings.warn(f\"No metadata found at {path}...\")\n        metadata = {\"columns\": {}}\n    return cls(data, metadata)\n
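For orientation, a minimal usage sketch follows. It assumes MetaData is importable from nhssynth.modules.dataloader.metadata (the module path shown above); the CSV and YAML paths are hypothetical placeholders.

```python
import pandas as pd

from nhssynth.modules.dataloader.metadata import MetaData

# Hypothetical inputs: any tabular dataset plus a YAML file describing its columns.
df = pd.read_csv("data/my_dataset.csv")

# If no file exists at the path, a warning is emitted and the metadata is inferred from `df`.
metadata = MetaData.from_path(df, "data/my_dataset_metadata.yaml")
```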
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.get_sdv_metadata","title":"get_sdv_metadata()","text":"

Map combinations of our metadata implementation to SDV's as required by SDMetrics.

Returns:

  • dict[str, dict[str, dict[str, str]]]: A dictionary containing the SDV metadata.

Source code in src/nhssynth/modules/dataloader/metadata.py
def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:\n    \"\"\"\n    Map combinations of our metadata implementation to SDV's as required by SDMetrics.\n\n    Returns:\n        A dictionary containing the SDV metadata.\n    \"\"\"\n    sdv_metadata = {\n        \"columns\": {\n            cn: {\n                \"sdtype\": \"boolean\"\n                if cmd.boolean\n                else \"categorical\"\n                if cmd.categorical\n                else \"datetime\"\n                if cmd.dtype.kind == \"M\"\n                else \"numerical\",\n            }\n            for cn, cmd in self._metadata.items()\n        }\n    }\n    for cn, cmd in self._metadata.items():\n        if cmd.dtype.kind == \"M\":\n            sdv_metadata[\"columns\"][cn][\"format\"] = cmd.datetime_config[\"format\"]\n    return sdv_metadata\n
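To make the mapping concrete, the sketch below shows the shape of dictionary this method returns, following the source above; the column names and datetime format are hypothetical.

```python
# Illustrative return value for a boolean flag, a category, a datetime and a numeric column.
sdv_metadata = {
    "columns": {
        "is_smoker": {"sdtype": "boolean"},
        "region": {"sdtype": "categorical"},
        "admission_date": {"sdtype": "datetime", "format": "%Y-%m-%d"},
        "length_of_stay": {"sdtype": "numerical"},
    }
}
```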
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.save","title":"save(path, collapse_yaml)","text":"

Writes metadata to a YAML file.

Parameters:

  • path (Path): The path at which to write the metadata YAML file. Required.
  • collapse_yaml (bool): A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication. Required.

Source code in src/nhssynth/modules/dataloader/metadata.py
def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:\n    \"\"\"\n    Writes metadata to a YAML file.\n\n    Args:\n        path: The path at which to write the metadata YAML file.\n        collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n    \"\"\"\n    with open(path, \"w\") as yaml_file:\n        yaml.safe_dump(\n            self._assemble(collapse_yaml),\n            yaml_file,\n            default_flow_style=False,\n            sort_keys=False,\n        )\n
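Continuing the hypothetical `metadata` object from the `from_path` sketch above, writing it back out might look like the following; the output path is a placeholder.

```python
import pathlib

# collapse_yaml=True reduces duplication in the written YAML representation.
metadata.save(pathlib.Path("data/my_dataset_metadata_out.yaml"), collapse_yaml=True)
```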
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.save_constraint_graphs","title":"save_constraint_graphs(path)","text":"

Output the constraint graphs as HTML files.

Parameters:

  • path (Path): The path at which to write the constraint graph HTML files. Required.

Source code in src/nhssynth/modules/dataloader/metadata.py
def save_constraint_graphs(self, path: pathlib.Path) -> None:\n    \"\"\"\n    Output the constraint graphs as HTML files.\n\n    Args:\n        path: The path at which to write the constraint graph HTML files.\n    \"\"\"\n    self.constraints._output_graphs_html(path)\n
"},{"location":"reference/modules/dataloader/metatransformer/","title":"metatransformer","text":""},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer","title":"MetaTransformer","text":"

The metatransformer is responsible for transforming the input dataset into a format that can be used by the model module, and for transforming that module's output back into the original format of the input dataset.

Parameters:

  • dataset (DataFrame): The raw input DataFrame. Required.
  • metadata (Optional[MetaData]): Optionally, a MetaData object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset. Default: None.
  • missingness_strategy (Optional[str]): The missingness strategy to use. Defaults to augmenting missing values in the data; see the missingness strategies for more information. Default: 'augment'.
  • impute_value (Optional[Any]): Only used when missingness_strategy is set to 'impute'. The value to use when imputing missing values in the data. Default: None.

After calling MetaTransformer.apply(), the following attributes and methods will be available:

Attributes:

  • typed_dataset (DataFrame): The dataset with the dtypes applied.
  • post_missingness_strategy_dataset (DataFrame): The dataset with the missingness strategies applied.
  • transformed_dataset (DataFrame): The transformed dataset.
  • single_column_indices (list[int]): The indices of the columns that were transformed into a single column.
  • multi_column_indices (list[list[int]]): The indices of the columns that were transformed into multiple columns.

Methods:

  • get_typed_dataset(): Returns the typed dataset.
  • get_prepared_dataset(): Returns the dataset with the missingness strategies applied.
  • get_transformed_dataset(): Returns the transformed dataset.
  • get_multi_and_single_column_indices(): Returns the indices of the columns that were transformed into one or multiple column(s).
  • get_sdv_metadata(): Returns the metadata in the correct format for SDMetrics.
  • save_metadata(): Saves the metadata to a file.
  • save_constraint_graphs(): Saves the constraint graphs to a file.

Note that mt.apply is a helper function that runs mt.apply_dtypes, mt.apply_missingness_strategy and mt.transform in sequence. This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
class MetaTransformer:\n    \"\"\"\n    The metatransformer is responsible for transforming input dataset into a format that can be used by the `model` module, and for transforming\n    this module's output back to the original format of the input dataset.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata: Optionally, a [`MetaData`][nhssynth.modules.dataloader.metadata.MetaData] object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset.\n        missingness_strategy: The missingness strategy to use. Defaults to augmenting missing values in the data, see [the missingness strategies][nhssynth.modules.dataloader.missingness] for more information.\n        impute_value: Only used when `missingness_strategy` is set to 'impute'. The value to use when imputing missing values in the data.\n\n    After calling `MetaTransformer.apply()`, the following attributes and methods will be available:\n\n    Attributes:\n        typed_dataset (pd.DataFrame): The dataset with the dtypes applied.\n        post_missingness_strategy_dataset (pd.DataFrame): The dataset with the missingness strategies applied.\n        transformed_dataset (pd.DataFrame): The transformed dataset.\n        single_column_indices (list[int]): The indices of the columns that were transformed into a single column.\n        multi_column_indices (list[list[int]]): The indices of the columns that were transformed into multiple columns.\n\n    **Methods:**\n\n    - `get_typed_dataset()`: Returns the typed dataset.\n    - `get_prepared_dataset()`: Returns the dataset with the missingness strategies applied.\n    - `get_transformed_dataset()`: Returns the transformed dataset.\n    - `get_multi_and_single_column_indices()`: Returns the indices of the columns that were transformed into one or multiple column(s).\n    - `get_sdv_metadata()`: Returns the metadata in the correct format for SDMetrics.\n    - `save_metadata()`: Saves the metadata to a file.\n    - `save_constraint_graphs()`: Saves the constraint graphs to a file.\n\n    Note that `mt.apply` is a helper function that runs `mt.apply_dtypes`, `mt.apply_missingness_strategy` and `mt.transform` in sequence.\n    This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: pd.DataFrame,\n        metadata: Optional[MetaData] = None,\n        missingness_strategy: Optional[str] = \"augment\",\n        impute_value: Optional[Any] = None,\n    ):\n        self._raw_dataset: pd.DataFrame = dataset\n        self._metadata: MetaData = metadata or MetaData(dataset)\n        if missingness_strategy == \"impute\":\n            assert (\n                impute_value is not None\n            ), \"`impute_value` of the `MetaTransformer` must be specified (via the --impute flag) when using the imputation missingness strategy\"\n            self._impute_value = impute_value\n        self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy]\n\n    @classmethod\n    def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:\n        \"\"\"\n        Instantiates a MetaTransformer from a metadata file via a provided path.\n\n        Args:\n            dataset: The raw input DataFrame.\n            metadata_path: The path to the metadata file.\n\n        Returns:\n            A MetaTransformer object.\n        \"\"\"\n        return cls(dataset, MetaData.from_path(dataset, 
metadata_path), **kwargs)\n\n    @classmethod\n    def from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:\n        \"\"\"\n        Instantiates a MetaTransformer from a metadata dictionary.\n\n        Args:\n            dataset: The raw input DataFrame.\n            metadata: A dictionary of raw metadata.\n\n        Returns:\n            A MetaTransformer object.\n        \"\"\"\n        return cls(dataset, MetaData(dataset, metadata), **kwargs)\n\n    def drop_columns(self) -> None:\n        \"\"\"\n        Drops columns from the dataset that are not in the `MetaData`.\n        \"\"\"\n        self._raw_dataset = self._raw_dataset[self._metadata.columns]\n\n    def _apply_rounding_scheme(self, working_column: pd.Series, rounding_scheme: float) -> pd.Series:\n        \"\"\"\n        A rounding scheme takes the form of the smallest value that should be rounded to 0, i.e. 0.01 for 2dp.\n        We first round to the nearest multiple in the standard way, through dividing, rounding and then multiplying.\n        However, this can lead to floating point errors, so we then round to the number of decimal places required by the rounding scheme.\n\n        e.g. `np.round(0.15 / 0.1) * 0.1` will erroneously return 0.1.\n\n        Args:\n            working_column: The column to apply the rounding scheme to.\n            rounding_scheme: The rounding scheme to apply.\n\n        Returns:\n            The column with the rounding scheme applied.\n        \"\"\"\n        working_column = np.round(working_column / rounding_scheme) * rounding_scheme\n        return working_column.round(max(0, int(np.ceil(np.log10(1 / rounding_scheme)))))\n\n    def _apply_dtype(\n        self,\n        working_column: pd.Series,\n        column_metadata: MetaData.ColumnMetaData,\n    ) -> pd.Series:\n        \"\"\"\n        Given a `working_column`, the dtype specified in the `column_metadata` is applied to it.\n         - Datetime columns are floored, and their format is inferred.\n         - Rounding schemes are applied to numeric columns if specified.\n         - Columns with missing values have their dtype converted to the pandas equivalent to allow for NA values.\n\n        Args:\n            working_column: The column to apply the dtype to.\n            column_metadata: The metadata for the column.\n\n        Returns:\n            The column with the dtype applied.\n        \"\"\"\n        dtype = column_metadata.dtype\n        try:\n            if dtype.kind == \"M\":\n                working_column = pd.to_datetime(working_column, format=column_metadata.datetime_config.get(\"format\"))\n                if column_metadata.datetime_config.get(\"floor\"):\n                    working_column = working_column.dt.floor(column_metadata.datetime_config.get(\"floor\"))\n                    column_metadata.datetime_config[\"format\"] = column_metadata._infer_datetime_format(working_column)\n                return working_column\n            else:\n                if hasattr(column_metadata, \"rounding_scheme\") and column_metadata.rounding_scheme is not None:\n                    working_column = self._apply_rounding_scheme(working_column, column_metadata.rounding_scheme)\n                # If there are missing values in the column, we need to use the pandas equivalent of the dtype to allow for NA values\n                if working_column.isnull().any() and dtype.kind in [\"i\", \"u\", \"f\"]:\n                    return working_column.astype(dtype.name.capitalize())\n                else:\n            
        return working_column.astype(dtype)\n        except ValueError:\n            raise ValueError(f\"{sys.exc_info()[1]}\\nError applying dtype '{dtype}' to column '{working_column.name}'\")\n\n    def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Applies dtypes from the metadata to `dataset`.\n\n        Returns:\n            The dataset with the dtypes applied.\n        \"\"\"\n        working_data = data.copy()\n        for column_metadata in self._metadata:\n            working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)\n        return working_data\n\n    def apply_missingness_strategy(self) -> pd.DataFrame:\n        \"\"\"\n        Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or\n        column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness\n        is not resolved, instead a new column / value is added for later transformation.\n\n        Returns:\n            The dataset with the missingness strategies applied.\n        \"\"\"\n        working_data = self.typed_dataset.copy()\n        for column_metadata in self._metadata:\n            if not column_metadata.missingness_strategy:\n                column_metadata.missingness_strategy = (\n                    self._missingness_strategy(self._impute_value)\n                    if hasattr(self, \"_impute_value\")\n                    else self._missingness_strategy()\n                )\n            if not working_data[column_metadata.name].isnull().any():\n                continue\n            working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)\n        return working_data\n\n    # def apply_constraints(self) -> pd.DataFrame:\n    #     working_data = self.post_missingness_strategy_dataset.copy()\n    #     for constraint in self._metadata.constraints:\n    #         working_data = constraint.apply(working_data)\n    #     return working_data\n\n    def _get_missingness_carrier(self, column_metadata: MetaData.ColumnMetaData) -> Union[pd.Series, Any]:\n        \"\"\"\n        In the case of the `AugmentMissingnessStrategy`, a `missingness_carrier` has been determined for each column.\n        For continuous columns this is an indicator column for the presence of NaN values.\n        For categorical columns this is the value to be used to represent missingness as a category.\n\n        Args:\n            column_metadata: The metadata for the column.\n\n        Returns:\n            The missingness carrier for the column.\n        \"\"\"\n        missingness_carrier = getattr(column_metadata.missingness_strategy, \"missingness_carrier\", None)\n        if missingness_carrier in self.post_missingness_strategy_dataset.columns:\n            return self.post_missingness_strategy_dataset[missingness_carrier]\n        else:\n            return missingness_carrier\n\n    def transform(self) -> pd.DataFrame:\n        \"\"\"\n        Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.\n\n        Returns:\n            The transformed dataset.\n        \"\"\"\n        transformed_columns = []\n        self.single_column_indices = []\n        self.multi_column_indices = []\n        col_counter = 0\n        working_data = self.post_missingness_strategy_dataset.copy()\n\n        # iteratively build the transformed df\n        for column_metadata in tqdm(\n   
         self._metadata, desc=\"Transforming data\", unit=\"column\", total=len(self._metadata.columns)\n        ):\n            missingness_carrier = self._get_missingness_carrier(column_metadata)\n            transformed_data = column_metadata.transformer.apply(\n                working_data[column_metadata.name], missingness_carrier\n            )\n            transformed_columns.append(transformed_data)\n\n            # track single and multi column indices to supply to the model\n            if isinstance(transformed_data, pd.DataFrame) and transformed_data.shape[1] > 1:\n                num_to_add = transformed_data.shape[1]\n                if not column_metadata.categorical:\n                    self.single_column_indices.append(col_counter)\n                    col_counter += 1\n                    num_to_add -= 1\n                self.multi_column_indices.append(list(range(col_counter, col_counter + num_to_add)))\n                col_counter += num_to_add\n            else:\n                self.single_column_indices.append(col_counter)\n                col_counter += 1\n\n        return pd.concat(transformed_columns, axis=1)\n\n    def apply(self) -> pd.DataFrame:\n        \"\"\"\n        Applies the various steps of the MetaTransformer to a passed DataFrame.\n\n        Returns:\n            The transformed dataset.\n        \"\"\"\n        self.drop_columns()\n        self.typed_dataset = self.apply_dtypes(self._raw_dataset)\n        self.post_missingness_strategy_dataset = self.apply_missingness_strategy()\n        # self.constrained_dataset = self.apply_constraints()\n        self.transformed_dataset = self.transform()\n        return self.transformed_dataset\n\n    def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Reverses the transformation applied by the MetaTransformer.\n\n        Args:\n            dataset: The transformed dataset.\n\n        Returns:\n            The original dataset.\n        \"\"\"\n        for column_metadata in self._metadata:\n            dataset = column_metadata.transformer.revert(dataset)\n        return self.apply_dtypes(dataset)\n\n    def get_typed_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"typed_dataset\"):\n            raise ValueError(\n                \"The typed dataset has not yet been created. Call `mt.apply()` (or `mt.apply_dtypes()`) first.\"\n            )\n        return self.typed_dataset\n\n    def get_prepared_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"prepared_dataset\"):\n            raise ValueError(\n                \"The prepared dataset has not yet been created. Call `mt.apply()` (or `mt.apply_missingness_strategy()`) first.\"\n            )\n        return self.prepared_dataset\n\n    def get_transformed_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"transformed_dataset\"):\n            raise ValueError(\n                \"The prepared dataset has not yet been created. 
Call `mt.apply()` (or `mt.transform()`) first.\"\n            )\n        return self.transformed_dataset\n\n    def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:\n        \"\"\"\n        Returns the indices of the columns that were transformed into one or multiple column(s).\n\n        Returns:\n            A tuple containing the indices of the single and multi columns.\n        \"\"\"\n        if not hasattr(self, \"multi_column_indices\") or not hasattr(self, \"single_column_indices\"):\n            raise ValueError(\n                \"The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first.\"\n            )\n        return self.multi_column_indices, self.single_column_indices\n\n    def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.\n\n        Returns:\n            The metadata in the correct format for SDMetrics.\n        \"\"\"\n        return self._metadata.get_sdv_metadata()\n\n    def save_metadata(self, path: pathlib.Path, collapse_yaml: bool = False) -> None:\n        return self._metadata.save(path, collapse_yaml)\n\n    def save_constraint_graphs(self, path: pathlib.Path) -> None:\n        return self._metadata.constraints._output_graphs_html(path)\n
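A minimal end-to-end sketch of the workflow described above, assuming the class is importable from nhssynth.modules.dataloader.metatransformer and using hypothetical file paths; in a real run the reverted frame would come from the model module's synthetic output rather than the transformed data itself.

```python
import pandas as pd

from nhssynth.modules.dataloader.metatransformer import MetaTransformer

df = pd.read_csv("data/my_dataset.csv")  # hypothetical raw dataset

# Build the MetaTransformer from a metadata YAML (metadata is inferred if the file is absent).
mt = MetaTransformer.from_path(df, "data/my_dataset_metadata.yaml")

# Apply dtypes, missingness handling and column transformers in sequence.
transformed = mt.apply()

# Column index bookkeeping required by the model module.
multi_column_indices, single_column_indices = mt.get_multi_and_single_column_indices()

# Map data in the transformed space back to the original format (here, the transformed
# data itself stands in for the model module's output).
reverted = mt.inverse_apply(transformed.copy())
```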
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply","title":"apply()","text":"

Applies the various steps of the MetaTransformer to a passed DataFrame.

Returns:

  • DataFrame: The transformed dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply(self) -> pd.DataFrame:\n    \"\"\"\n    Applies the various steps of the MetaTransformer to a passed DataFrame.\n\n    Returns:\n        The transformed dataset.\n    \"\"\"\n    self.drop_columns()\n    self.typed_dataset = self.apply_dtypes(self._raw_dataset)\n    self.post_missingness_strategy_dataset = self.apply_missingness_strategy()\n    # self.constrained_dataset = self.apply_constraints()\n    self.transformed_dataset = self.transform()\n    return self.transformed_dataset\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply_dtypes","title":"apply_dtypes(data)","text":"

Applies dtypes from the metadata to the dataset.

Returns:

  • DataFrame: The dataset with the dtypes applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Applies dtypes from the metadata to `dataset`.\n\n    Returns:\n        The dataset with the dtypes applied.\n    \"\"\"\n    working_data = data.copy()\n    for column_metadata in self._metadata:\n        working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)\n    return working_data\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply_missingness_strategy","title":"apply_missingness_strategy()","text":"

Resolves missingness in the dataset via the MetaTransformer's global missingness strategy or column-wise missingness strategies. In the case of the AugmentMissingnessStrategy, the missingness is not resolved; instead, a new column / value is added for later transformation.

Returns:

  • DataFrame: The dataset with the missingness strategies applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_missingness_strategy(self) -> pd.DataFrame:\n    \"\"\"\n    Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or\n    column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness\n    is not resolved, instead a new column / value is added for later transformation.\n\n    Returns:\n        The dataset with the missingness strategies applied.\n    \"\"\"\n    working_data = self.typed_dataset.copy()\n    for column_metadata in self._metadata:\n        if not column_metadata.missingness_strategy:\n            column_metadata.missingness_strategy = (\n                self._missingness_strategy(self._impute_value)\n                if hasattr(self, \"_impute_value\")\n                else self._missingness_strategy()\n            )\n        if not working_data[column_metadata.name].isnull().any():\n            continue\n        working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)\n    return working_data\n
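To make the default 'augment' behaviour concrete for a continuous column, the sketch below reproduces in plain pandas what the AugmentMissingnessStrategy shown later in this reference does; the column name is hypothetical.

```python
import numpy as np
import pandas as pd

data = pd.DataFrame({"blood_pressure": [120.0, np.nan, 135.0, np.nan]})

# Continuous features gain a 0/1 indicator column; categorical features instead gain a
# dedicated 'missing' category or value. The NaNs themselves are handled at transform time.
data["blood_pressure_missing"] = data["blood_pressure"].isnull().astype(int)
print(data)
#    blood_pressure  blood_pressure_missing
# 0           120.0                       0
# 1             NaN                       1
# 2           135.0                       0
# 3             NaN                       1
```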
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.drop_columns","title":"drop_columns()","text":"

Drops columns from the dataset that are not in the MetaData.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def drop_columns(self) -> None:\n    \"\"\"\n    Drops columns from the dataset that are not in the `MetaData`.\n    \"\"\"\n    self._raw_dataset = self._raw_dataset[self._metadata.columns]\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.from_dict","title":"from_dict(dataset, metadata, **kwargs) classmethod","text":"

Instantiates a MetaTransformer from a metadata dictionary.

Parameters:

  • dataset (DataFrame): The raw input DataFrame. Required.
  • metadata (dict): A dictionary of raw metadata. Required.

Returns:

  • Self: A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod\ndef from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:\n    \"\"\"\n    Instantiates a MetaTransformer from a metadata dictionary.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata: A dictionary of raw metadata.\n\n    Returns:\n        A MetaTransformer object.\n    \"\"\"\n    return cls(dataset, MetaData(dataset, metadata), **kwargs)\n
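A sketch of `from_dict` with an inline metadata dictionary. The top-level `columns` key and the per-column keys (`dtype`, `categorical`, `missingness`, `transformer`) mirror the fields read by `ColumnMetaData` and `MetaData.from_path` above; the example data and column choices are hypothetical.

```python
import pandas as pd

from nhssynth.modules.dataloader.metatransformer import MetaTransformer

df = pd.DataFrame(
    {
        "region": ["north", "south", None, "south", "north", "east"],
        "is_smoker": [0, 1, 1, 0, 0, 1],
    }
)

metadata = {
    "columns": {
        "region": {"dtype": "object", "categorical": True, "transformer": "OHECategoricalTransformer"},
        # The impute strategy only takes effect if the column actually contains missing values.
        "is_smoker": {"dtype": "int64", "categorical": True, "missingness": {"impute": "mode"}},
    }
}

mt = MetaTransformer.from_dict(df, metadata)
transformed = mt.apply()
```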
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.from_path","title":"from_path(dataset, metadata_path, **kwargs) classmethod","text":"

Instantiates a MetaTransformer from a metadata file via a provided path.

Parameters:

  • dataset (DataFrame): The raw input DataFrame. Required.
  • metadata_path (str): The path to the metadata file. Required.

Returns:

  • Self: A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod\ndef from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:\n    \"\"\"\n    Instantiates a MetaTransformer from a metadata file via a provided path.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata_path: The path to the metadata file.\n\n    Returns:\n        A MetaTransformer object.\n    \"\"\"\n    return cls(dataset, MetaData.from_path(dataset, metadata_path), **kwargs)\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.get_multi_and_single_column_indices","title":"get_multi_and_single_column_indices()","text":"

Returns the indices of the columns that were transformed into one or multiple column(s).

Returns:

  • tuple[list[int], list[int]]: A tuple containing the indices of the multi and single columns, in that order.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:\n    \"\"\"\n    Returns the indices of the columns that were transformed into one or multiple column(s).\n\n    Returns:\n        A tuple containing the indices of the single and multi columns.\n    \"\"\"\n    if not hasattr(self, \"multi_column_indices\") or not hasattr(self, \"single_column_indices\"):\n        raise ValueError(\n            \"The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first.\"\n        )\n    return self.multi_column_indices, self.single_column_indices\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.get_sdv_metadata","title":"get_sdv_metadata()","text":"

Calls the MetaData method to reformat its contents into the correct format for use with SDMetrics.

Returns:

  • dict[str, dict[str, Any]]: The metadata in the correct format for SDMetrics.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:\n    \"\"\"\n    Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.\n\n    Returns:\n        The metadata in the correct format for SDMetrics.\n    \"\"\"\n    return self._metadata.get_sdv_metadata()\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.inverse_apply","title":"inverse_apply(dataset)","text":"

Reverses the transformation applied by the MetaTransformer.

Parameters:

  • dataset (DataFrame): The transformed dataset. Required.

Returns:

  • DataFrame: The original dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Reverses the transformation applied by the MetaTransformer.\n\n    Args:\n        dataset: The transformed dataset.\n\n    Returns:\n        The original dataset.\n    \"\"\"\n    for column_metadata in self._metadata:\n        dataset = column_metadata.transformer.revert(dataset)\n    return self.apply_dtypes(dataset)\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.transform","title":"transform()","text":"

Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.

Returns:

  • DataFrame: The transformed dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def transform(self) -> pd.DataFrame:\n    \"\"\"\n    Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.\n\n    Returns:\n        The transformed dataset.\n    \"\"\"\n    transformed_columns = []\n    self.single_column_indices = []\n    self.multi_column_indices = []\n    col_counter = 0\n    working_data = self.post_missingness_strategy_dataset.copy()\n\n    # iteratively build the transformed df\n    for column_metadata in tqdm(\n        self._metadata, desc=\"Transforming data\", unit=\"column\", total=len(self._metadata.columns)\n    ):\n        missingness_carrier = self._get_missingness_carrier(column_metadata)\n        transformed_data = column_metadata.transformer.apply(\n            working_data[column_metadata.name], missingness_carrier\n        )\n        transformed_columns.append(transformed_data)\n\n        # track single and multi column indices to supply to the model\n        if isinstance(transformed_data, pd.DataFrame) and transformed_data.shape[1] > 1:\n            num_to_add = transformed_data.shape[1]\n            if not column_metadata.categorical:\n                self.single_column_indices.append(col_counter)\n                col_counter += 1\n                num_to_add -= 1\n            self.multi_column_indices.append(list(range(col_counter, col_counter + num_to_add)))\n            col_counter += num_to_add\n        else:\n            self.single_column_indices.append(col_counter)\n            col_counter += 1\n\n    return pd.concat(transformed_columns, axis=1)\n
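The index bookkeeping above can be summarised with a worked, hypothetical layout: a continuous column expanded into one normalised column plus three cluster-component columns, followed by a categorical column one-hot encoded into two columns.

```python
# Hypothetical transformed layout:
#   index 0      -> 'age' normalised value          -> recorded in single_column_indices
#   indices 1-3  -> 'age' cluster-component columns -> one block in multi_column_indices
#   indices 4-5  -> 'region' one-hot categories     -> another block in multi_column_indices
multi_column_indices = [[1, 2, 3], [4, 5]]
single_column_indices = [0]
```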
"},{"location":"reference/modules/dataloader/missingness/","title":"missingness","text":""},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.AugmentMissingnessStrategy","title":"AugmentMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Source code in src/nhssynth/modules/dataloader/missingness.py
class AugmentMissingnessStrategy(GenericMissingnessStrategy):\n    def __init__(self) -> None:\n        super().__init__(\"augment\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata enabling the correct set up of the missingness strategy.\n\n        Returns:\n            The dataset, potentially with a new column representing the missingness for the column added.\n        \"\"\"\n        if column_metadata.categorical:\n            if column_metadata.dtype.kind == \"O\":\n                self.missingness_carrier = column_metadata.name + \"_missing\"\n            else:\n                self.missingness_carrier = data[column_metadata.name].min() - 1\n        else:\n            self.missingness_carrier = column_metadata.name + \"_missing\"\n            data[self.missingness_carrier] = data[column_metadata.name].isnull().astype(int)\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.AugmentMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.

Parameters:

  • data (DataFrame): The dataset. Required.
  • column_metadata (ColumnMetaData): The column metadata enabling the correct set up of the missingness strategy. Required.

Returns:

  • DataFrame: The dataset, potentially with a new column representing the missingness for the column added.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata enabling the correct set up of the missingness strategy.\n\n    Returns:\n        The dataset, potentially with a new column representing the missingness for the column added.\n    \"\"\"\n    if column_metadata.categorical:\n        if column_metadata.dtype.kind == \"O\":\n            self.missingness_carrier = column_metadata.name + \"_missing\"\n        else:\n            self.missingness_carrier = data[column_metadata.name].min() - 1\n    else:\n        self.missingness_carrier = column_metadata.name + \"_missing\"\n        data[self.missingness_carrier] = data[column_metadata.name].isnull().astype(int)\n    return data\n
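For the categorical branch, the carrier depends on the column's dtype, as the source above shows. A small pandas sketch with hypothetical columns:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"region": ["north", None, "south"], "score_band": [1, 2, np.nan]})

# Object-typed categoricals get a new category label as their carrier...
carrier_region = "region_missing"

# ...while numerically coded categoricals get a value one below the observed minimum.
carrier_score_band = df["score_band"].min() - 1  # 0.0 here

# In both cases the NaNs are only filled later, when the column's transformer is applied
# with the carrier (see OHECategoricalTransformer.apply below).
```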
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.DropMissingnessStrategy","title":"DropMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Drop missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class DropMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Drop missingness strategy.\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__(\"drop\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Drop rows containing missing values in the appropriate column.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata.\n\n        Returns:\n            The dataset with rows containing missing values in the appropriate column dropped.\n        \"\"\"\n        return data.dropna(subset=[column_metadata.name]).reset_index(drop=True)\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.DropMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Drop rows containing missing values in the appropriate column.

Parameters:

  • data (DataFrame): The dataset. Required.
  • column_metadata (ColumnMetaData): The column metadata. Required.

Returns:

  • DataFrame: The dataset with rows containing missing values in the appropriate column dropped.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Drop rows containing missing values in the appropriate column.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata.\n\n    Returns:\n        The dataset with rows containing missing values in the appropriate column dropped.\n    \"\"\"\n    return data.dropna(subset=[column_metadata.name]).reset_index(drop=True)\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.GenericMissingnessStrategy","title":"GenericMissingnessStrategy","text":"

Bases: ABC

Generic missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class GenericMissingnessStrategy(ABC):\n    \"\"\"Generic missingness strategy.\"\"\"\n\n    def __init__(self, name: str) -> None:\n        super().__init__()\n        self.name: str = name\n\n    @abstractmethod\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"Remove missingness.\"\"\"\n        pass\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.GenericMissingnessStrategy.remove","title":"remove(data, column_metadata) abstractmethod","text":"

Remove missingness.

Source code in src/nhssynth/modules/dataloader/missingness.py
@abstractmethod\ndef remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"Remove missingness.\"\"\"\n    pass\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.ImputeMissingnessStrategy","title":"ImputeMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Impute missingness with a specified value or summary statistic (mean, median, or mode).

Source code in src/nhssynth/modules/dataloader/missingness.py
class ImputeMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Impute missingness with mean strategy.\"\"\"\n\n    def __init__(self, impute: Any) -> None:\n        super().__init__(\"impute\")\n        self.impute = impute.lower() if isinstance(impute, str) else impute\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Impute missingness in the data via the `impute` strategy. 'Special' values trigger specific behaviour.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata.\n\n        Returns:\n            The dataset with missing values in the appropriate column replaced with imputed ones.\n        \"\"\"\n        if (self.impute == \"mean\" or self.impute == \"median\") and column_metadata.categorical:\n            warnings.warn(\"Cannot impute mean or median for categorical data, using mode instead.\")\n            self.imputation_value = data[column_metadata.name].mode()[0]\n        elif self.impute == \"mean\":\n            self.imputation_value = data[column_metadata.name].mean()\n        elif self.impute == \"median\":\n            self.imputation_value = data[column_metadata.name].median()\n        elif self.impute == \"mode\":\n            self.imputation_value = data[column_metadata.name].mode()[0]\n        else:\n            self.imputation_value = self.impute\n        self.imputation_value = column_metadata.dtype.type(self.imputation_value)\n        try:\n            data[column_metadata.name].fillna(self.imputation_value, inplace=True)\n        except AssertionError:\n            raise ValueError(f\"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.\")\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.ImputeMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Impute missingness in the data via the impute strategy. 'Special' values trigger specific behaviour.

Parameters:

  • data (DataFrame): The dataset. Required.
  • column_metadata (ColumnMetaData): The column metadata. Required.

Returns:

  • DataFrame: The dataset with missing values in the appropriate column replaced with imputed ones.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Impute missingness in the data via the `impute` strategy. 'Special' values trigger specific behaviour.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata.\n\n    Returns:\n        The dataset with missing values in the appropriate column replaced with imputed ones.\n    \"\"\"\n    if (self.impute == \"mean\" or self.impute == \"median\") and column_metadata.categorical:\n        warnings.warn(\"Cannot impute mean or median for categorical data, using mode instead.\")\n        self.imputation_value = data[column_metadata.name].mode()[0]\n    elif self.impute == \"mean\":\n        self.imputation_value = data[column_metadata.name].mean()\n    elif self.impute == \"median\":\n        self.imputation_value = data[column_metadata.name].median()\n    elif self.impute == \"mode\":\n        self.imputation_value = data[column_metadata.name].mode()[0]\n    else:\n        self.imputation_value = self.impute\n    self.imputation_value = column_metadata.dtype.type(self.imputation_value)\n    try:\n        data[column_metadata.name].fillna(self.imputation_value, inplace=True)\n    except AssertionError:\n        raise ValueError(f\"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.\")\n    return data\n
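A plain-pandas sketch of how the 'special' values resolve for a hypothetical numeric column; in normal use the strategy is driven by the MetaTransformer rather than called directly.

```python
import pandas as pd

s = pd.Series([4.0, None, 6.0, None], name="length_of_stay")

# impute='mean'   -> fill with s.mean()      (falls back to mode for categorical columns)
# impute='median' -> fill with s.median()    (same fallback)
# impute='mode'   -> fill with s.mode()[0]
# impute=0 (etc.) -> any other value is used as-is
filled = s.fillna(s.mean())  # what ImputeMissingnessStrategy('mean') would produce here
```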
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.NullMissingnessStrategy","title":"NullMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Null missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class NullMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Null missingness strategy.\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__(\"none\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"Do nothing.\"\"\"\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.NullMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Do nothing.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"Do nothing.\"\"\"\n    return data\n
"},{"location":"reference/modules/dataloader/run/","title":"run","text":""},{"location":"reference/modules/dataloader/transformers/","title":"transformers","text":""},{"location":"reference/modules/dataloader/transformers/base/","title":"base","text":""},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer","title":"ColumnTransformer","text":"

Bases: ABC

A generic column transformer class to prototype all of the transformers applied via the MetaTransformer.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
class ColumnTransformer(ABC):\n    \"\"\"A generic column transformer class to prototype all of the transformers applied via the [`MetaTransformer`][nhssynth.modules.dataloader.metatransformer.MetaTransformer].\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    @abstractmethod\n    def apply(self, data: pd.DataFrame, missingness_column: Optional[pd.Series]) -> None:\n        \"\"\"Apply the transformer to the data.\"\"\"\n        pass\n\n    @abstractmethod\n    def revert(self, data: pd.DataFrame) -> None:\n        \"\"\"Revert data to pre-transformer state.\"\"\"\n        pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer.apply","title":"apply(data, missingness_column) abstractmethod","text":"

Apply the transformer to the data.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
@abstractmethod\ndef apply(self, data: pd.DataFrame, missingness_column: Optional[pd.Series]) -> None:\n    \"\"\"Apply the transformer to the data.\"\"\"\n    pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer.revert","title":"revert(data) abstractmethod","text":"

Revert data to pre-transformer state.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
@abstractmethod\ndef revert(self, data: pd.DataFrame) -> None:\n    \"\"\"Revert data to pre-transformer state.\"\"\"\n    pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper","title":"TransformerWrapper","text":"

Bases: ABC

A class to facilitate nesting of ColumnTransformers.

Parameters:

  • wrapped_transformer (ColumnTransformer): The ColumnTransformer to wrap. Required.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
class TransformerWrapper(ABC):\n    \"\"\"\n    A class to facilitate nesting of [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer]s.\n\n    Args:\n        wrapped_transformer: The [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer] to wrap.\n    \"\"\"\n\n    def __init__(self, wrapped_transformer: ColumnTransformer) -> None:\n        super().__init__()\n        self._wrapped_transformer: ColumnTransformer = wrapped_transformer\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series], **kwargs) -> pd.DataFrame:\n        \"\"\"Method for applying the wrapped transformer to the data.\"\"\"\n        return self._wrapped_transformer.apply(data, missingness_column, **kwargs)\n\n    def revert(self, data: pd.Series, **kwargs) -> pd.DataFrame:\n        \"\"\"Method for reverting the passed data via the wrapped transformer.\"\"\"\n        return self._wrapped_transformer.revert(data, **kwargs)\n
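A sketch of the nesting pattern, assuming the classes are importable from the module paths shown; the no-op subclass is purely hypothetical and exists only to illustrate that a wrapper delegates `apply`/`revert` to whatever transformer it wraps (this is how `DatetimeTransformer` layers datetime handling around a `ClusterContinuousTransformer` in `ColumnMetaData._infer_transformer`).

```python
from nhssynth.modules.dataloader.transformers.base import TransformerWrapper
from nhssynth.modules.dataloader.transformers.continuous import ClusterContinuousTransformer


class PassthroughWrapper(TransformerWrapper):
    """Hypothetical no-op wrapper: inherits the delegating apply() and revert() unchanged."""


# The wrapped transformer keeps its own configuration; the wrapper adds behaviour around it.
wrapped = PassthroughWrapper(ClusterContinuousTransformer(n_components=5))
```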
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper.apply","title":"apply(data, missingness_column, **kwargs)","text":"

Method for applying the wrapped transformer to the data.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series], **kwargs) -> pd.DataFrame:\n    \"\"\"Method for applying the wrapped transformer to the data.\"\"\"\n    return self._wrapped_transformer.apply(data, missingness_column, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper.revert","title":"revert(data, **kwargs)","text":"

Method for reverting the passed data via the wrapped transformer.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
def revert(self, data: pd.Series, **kwargs) -> pd.DataFrame:\n    \"\"\"Method for reverting the passed data via the wrapped transformer.\"\"\"\n    return self._wrapped_transformer.revert(data, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/categorical/","title":"categorical","text":""},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer","title":"OHECategoricalTransformer","text":"

Bases: ColumnTransformer

A transformer to one-hot encode categorical features via sklearn's OneHotEncoder. Essentially wraps the fit_transform and inverse_transform methods of OneHotEncoder to comply with the ColumnTransformer interface.

Parameters:

  • drop (Optional[Union[list, str]]): str or list of str, to pass to OneHotEncoder's drop parameter. Default: None.

Attributes:

  • missing_value (Any): The value used to fill missing values in the data.

After applying the transformer, the following attributes will be populated:

Attributes:

  • original_column_name: The name of the original column.
  • new_column_names: The names of the columns generated by the transformer.

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
class OHECategoricalTransformer(ColumnTransformer):\n    \"\"\"\n    A transformer to one-hot encode categorical features via sklearn's `OneHotEncoder`.\n    Essentially wraps the `fit_transformer` and `inverse_transform` methods of `OneHotEncoder` to comply with the `ColumnTransformer` interface.\n\n    Args:\n        drop: str or list of str, to pass to `OneHotEncoder`'s `drop` parameter.\n\n    Attributes:\n        missing_value: The value used to fill missing values in the data.\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        original_column_name: The name of the original column.\n        new_column_names: The names of the columns generated by the transformer.\n    \"\"\"\n\n    def __init__(self, drop: Optional[Union[list, str]] = None) -> None:\n        super().__init__()\n        self._drop: Union[list, str] = drop\n        self._transformer: OneHotEncoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=self._drop)\n        self.missing_value: Any = None\n\n    def apply(self, data: pd.Series, missing_value: Optional[Any] = None) -> pd.DataFrame:\n        \"\"\"\n        Apply the transformer to the data via sklearn's `OneHotEncoder`'s `fit_transform` method. Name the new columns via manipulation of the original column name.\n        If `missing_value` is provided, fill missing values with this value before applying the transformer to ensure a new category is added.\n\n        Args:\n            data: The column of data to transform.\n            missing_value: The value learned by the `MetaTransformer` to represent missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n        \"\"\"\n        self.original_column_name = data.name\n        if missing_value:\n            data = data.fillna(missing_value)\n            self.missing_value = missing_value\n        transformed_data = pd.DataFrame(\n            self._transformer.fit_transform(data.values.reshape(-1, 1)),\n            columns=self._transformer.get_feature_names_out(input_features=[data.name]),\n        )\n        self.new_column_names = transformed_data.columns\n        return transformed_data\n\n    def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Revert data to pre-transformer state via sklearn's `OneHotEncoder`'s `inverse_transform` method.\n        If `missing_value` is provided, replace instances of this value in the data with `np.nan` to ensure missing values are represented correctly in the case\n        where `missing_value` was 'modelled' and thus generated.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.\n        \"\"\"\n        data[self.original_column_name] = pd.Series(\n            self._transformer.inverse_transform(data[self.new_column_names].values).flatten(),\n            index=data.index,\n            name=self.original_column_name,\n        )\n        if self.missing_value:\n            data[self.original_column_name] = data[self.original_column_name].replace(self.missing_value, np.nan)\n        return data.drop(self.new_column_names, axis=1)\n
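A round-trip sketch, assuming the class is importable from nhssynth.modules.dataloader.transformers.categorical; the column and carrier value are hypothetical, and the exact one-hot column names depend on the observed categories.

```python
import pandas as pd

from nhssynth.modules.dataloader.transformers.categorical import OHECategoricalTransformer

s = pd.Series(["north", None, "south", "north"], name="region")

ohe = OHECategoricalTransformer()

# Passing a missing_value (the AugmentMissingnessStrategy carrier) adds a 'missing' category.
onehot = ohe.apply(s, missing_value="region_missing")
print(list(onehot.columns))  # e.g. ['region_north', 'region_region_missing', 'region_south']

# revert() expects the full transformed frame; it restores a single 'region' column and maps
# the carrier category back to NaN.
restored = ohe.revert(onehot.copy())
```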
"},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer.apply","title":"apply(data, missing_value=None)","text":"

Apply the transformer to the data via sklearn's OneHotEncoder's fit_transform method. Name the new columns via manipulation of the original column name. If missing_value is provided, fill missing values with this value before applying the transformer to ensure a new category is added.

Parameters:

  • data (Series): The column of data to transform. Required.
  • missing_value (Optional[Any]): The value learned by the MetaTransformer to represent missingness; this is only used as part of the AugmentMissingnessStrategy. Default: None.

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
def apply(self, data: pd.Series, missing_value: Optional[Any] = None) -> pd.DataFrame:\n    \"\"\"\n    Apply the transformer to the data via sklearn's `OneHotEncoder`'s `fit_transform` method. Name the new columns via manipulation of the original column name.\n    If `missing_value` is provided, fill missing values with this value before applying the transformer to ensure a new category is added.\n\n    Args:\n        data: The column of data to transform.\n        missing_value: The value learned by the `MetaTransformer` to represent missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n    \"\"\"\n    self.original_column_name = data.name\n    if missing_value:\n        data = data.fillna(missing_value)\n        self.missing_value = missing_value\n    transformed_data = pd.DataFrame(\n        self._transformer.fit_transform(data.values.reshape(-1, 1)),\n        columns=self._transformer.get_feature_names_out(input_features=[data.name]),\n    )\n    self.new_column_names = transformed_data.columns\n    return transformed_data\n
"},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer.revert","title":"revert(data)","text":"

Revert data to pre-transformer state via sklearn's OneHotEncoder's inverse_transform method. If missing_value is provided, replace instances of this value in the data with np.nan to ensure missing values are represented correctly in the case where missing_value was 'modelled' and thus generated.

Parameters:

Name Type Description Default data DataFrame

The full dataset including the column(s) to be reverted to their pre-transformer state.

required

Returns:

Type Description DataFrame

The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Revert data to pre-transformer state via sklearn's `OneHotEncoder`'s `inverse_transform` method.\n    If `missing_value` is provided, replace instances of this value in the data with `np.nan` to ensure missing values are represented correctly in the case\n    where `missing_value` was 'modelled' and thus generated.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.\n    \"\"\"\n    data[self.original_column_name] = pd.Series(\n        self._transformer.inverse_transform(data[self.new_column_names].values).flatten(),\n        index=data.index,\n        name=self.original_column_name,\n    )\n    if self.missing_value:\n        data[self.original_column_name] = data[self.original_column_name].replace(self.missing_value, np.nan)\n    return data.drop(self.new_column_names, axis=1)\n
"},{"location":"reference/modules/dataloader/transformers/continuous/","title":"continuous","text":""},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer","title":"ClusterContinuousTransformer","text":"

Bases: ColumnTransformer

A transformer to cluster continuous features via sklearn's BayesianGaussianMixture. Essentially wraps the process of fitting the BGM model and generating cluster assignments and normalised values for the data to comply with the ColumnTransformer interface.

Parameters:

Name Type Description Default n_components int

The number of components to use in the BGM model.

10 n_init int

The number of initialisations to use in the BGM model.

1 init_params str

The initialisation method to use in the BGM model.

'kmeans' random_state int

The random state to use in the BGM model.

0 max_iter int

The maximum number of iterations to use in the BGM model.

1000 remove_unused_components bool

Whether to remove components that have no data assigned to them (EXPERIMENTAL).

False clip_output bool

Whether to clip the output normalised values to the range [-1, 1].

False

After applying the transformer, the following attributes will be populated:

Attributes:

Name Type Description means

The means of the components in the BGM model.

stds

The standard deviations of the components in the BGM model.

new_column_names

The names of the columns generated by the transformer (one for the normalised values and one for each cluster component).

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
class ClusterContinuousTransformer(ColumnTransformer):\n    \"\"\"\n    A transformer to cluster continuous features via sklearn's `BayesianGaussianMixture`.\n    Essentially wraps the process of fitting the BGM model and generating cluster assignments and normalised values for the data to comply with the `ColumnTransformer` interface.\n\n    Args:\n        n_components: The number of components to use in the BGM model.\n        n_init: The number of initialisations to use in the BGM model.\n        init_params: The initialisation method to use in the BGM model.\n        random_state: The random state to use in the BGM model.\n        max_iter: The maximum number of iterations to use in the BGM model.\n        remove_unused_components: Whether to remove components that have no data assigned EXPERIMENTAL.\n        clip_output: Whether to clip the output normalised values to the range [-1, 1].\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        means: The means of the components in the BGM model.\n        stds: The standard deviations of the components in the BGM model.\n        new_column_names: The names of the columns generated by the transformer (one for the normalised values and one for each cluster component).\n    \"\"\"\n\n    def __init__(\n        self,\n        n_components: int = 10,\n        n_init: int = 1,\n        init_params: str = \"kmeans\",\n        random_state: int = 0,\n        max_iter: int = 1000,\n        remove_unused_components: bool = False,\n        clip_output: bool = False,\n    ) -> None:\n        super().__init__()\n        self._transformer = BayesianGaussianMixture(\n            n_components=n_components,\n            random_state=random_state,\n            n_init=n_init,\n            init_params=init_params,\n            max_iter=max_iter,\n            weight_concentration_prior=1e-3,\n        )\n        self._n_components = n_components\n        self._std_multiplier = 4\n        self._missingness_column_name = None\n        self._max_iter = max_iter\n        self.remove_unused_components = remove_unused_components\n        self.clip_output = clip_output\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None) -> pd.DataFrame:\n        \"\"\"\n        Apply the transformer to the data via sklearn's `BayesianGaussianMixture`'s `fit` and `predict_proba` methods.\n        Name the new columns via the original column name.\n\n        If `missingness_column` is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0\n        (i.e. all values in the normalised column are 0.0). 
We do this by taking the full index before subsetting to non-missing data, then reindexing.\n\n        Args:\n            data: The column of data to transform.\n            missingness_column: The column of data representing missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n\n        Returns:\n            The transformed data (will be multiple columns if `n_components` > 1 at initialisation).\n        \"\"\"\n        self.original_column_name = data.name\n        if missingness_column is not None:\n            self._missingness_column_name = missingness_column.name\n            full_index = data.index\n            data = data[missingness_column == 0]\n        index = data.index\n        data = np.array(data.values.reshape(-1, 1), dtype=data.dtype.name.lower())\n\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n            self._transformer.fit(data)\n\n        self.means = self._transformer.means_.reshape(-1)\n        self.stds = np.sqrt(self._transformer.covariances_).reshape(-1)\n\n        components = np.argmax(self._transformer.predict_proba(data), axis=1)\n        normalised_values = (data - self.means.reshape(1, -1)) / (self._std_multiplier * self.stds.reshape(1, -1))\n        normalised = normalised_values[np.arange(len(data)), components]\n        normalised = np.clip(normalised, -1.0, 1.0)\n        components = np.eye(self._n_components, dtype=int)[components]\n\n        transformed_data = pd.DataFrame(\n            np.hstack([normalised.reshape(-1, 1), components]),\n            index=index,\n            columns=[f\"{self.original_column_name}_normalised\"]\n            + [f\"{self.original_column_name}_c{i + 1}\" for i in range(self._n_components)],\n        )\n\n        # EXPERIMENTAL feature, removing components from the column matrix that have no data assigned to them\n        if self.remove_unused_components:\n            nunique = transformed_data.iloc[:, 1:].nunique(dropna=False)\n            unused_components = nunique[nunique == 1].index\n            unused_component_idx = [transformed_data.columns.get_loc(col_name) - 1 for col_name in unused_components]\n            self.means = np.delete(self.means, unused_component_idx)\n            self.stds = np.delete(self.stds, unused_component_idx)\n            transformed_data.drop(unused_components, axis=1, inplace=True)\n\n        if missingness_column is not None:\n            transformed_data = pd.concat([transformed_data.reindex(full_index).fillna(0.0), missingness_column], axis=1)\n\n        self.new_column_names = transformed_data.columns\n        return transformed_data.astype(\n            {col_name: int for col_name in transformed_data.columns if re.search(r\"_c\\d+\", col_name)}\n        )\n\n    def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Revert data to pre-transformer state via the means and stds of the BGM. 
Extract the relevant columns from the data via the `new_column_names` attribute.\n        If `missingness_column` was provided to the `apply` method, drop the missing values from the data before reverting and use the `full_index` to\n        reintroduce missing values when `original_column_name` is constructed.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.\n        \"\"\"\n        working_data = data[self.new_column_names]\n        full_index = working_data.index\n        if self._missingness_column_name is not None:\n            working_data = working_data[working_data[self._missingness_column_name] == 0]\n            working_data = working_data.drop(self._missingness_column_name, axis=1)\n        index = working_data.index\n\n        components = np.argmax(working_data.filter(regex=r\".*_c\\d+\").values, axis=1)\n        working_data = working_data.filter(like=\"_normalised\").values.reshape(-1)\n        if self.clip_output:\n            working_data = np.clip(working_data, -1.0, 1.0)\n\n        mean_t = self.means[components]\n        std_t = self.stds[components]\n        data[self.original_column_name] = pd.Series(\n            working_data * self._std_multiplier * std_t + mean_t, index=index, name=self.original_column_name\n        ).reindex(full_index)\n        data.drop(self.new_column_names, axis=1, inplace=True)\n        return data\n
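A minimal usage sketch of the fit-and-normalise round trip (assumed import path, synthetic data):

import numpy as np
import pandas as pd

from nhssynth.modules.dataloader.transformers.continuous import ClusterContinuousTransformer

rng = np.random.default_rng(0)
age = pd.Series(rng.normal(50.0, 10.0, size=500), name="age")

transformer = ClusterContinuousTransformer(n_components=5)
encoded = transformer.apply(age)
# encoded holds "age_normalised" plus one 0/1 indicator column per component

restored = transformer.revert(encoded.copy())
# restored["age"] recovers the original values up to the clipping applied in apply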
"},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer.apply","title":"apply(data, missingness_column=None)","text":"

Apply the transformer to the data via sklearn's BayesianGaussianMixture's fit and predict_proba methods. Name the new columns via the original column name.

If missingness_column is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0 (i.e. all values in the normalised column are 0.0). We do this by taking the full index before subsetting to non-missing data, then reindexing.

Parameters:

Name Type Description Default data Series

The column of data to transform.

required missingness_column Optional[Series]

The column of data representing missingness, this is only used as part of the AugmentMissingnessStrategy.

None

Returns:

Type Description DataFrame

The transformed data (will be multiple columns if n_components > 1 at initialisation).

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None) -> pd.DataFrame:\n    \"\"\"\n    Apply the transformer to the data via sklearn's `BayesianGaussianMixture`'s `fit` and `predict_proba` methods.\n    Name the new columns via the original column name.\n\n    If `missingness_column` is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0\n    (i.e. all values in the normalised column are 0.0). We do this by taking the full index before subsetting to non-missing data, then reindexing.\n\n    Args:\n        data: The column of data to transform.\n        missingness_column: The column of data representing missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n\n    Returns:\n        The transformed data (will be multiple columns if `n_components` > 1 at initialisation).\n    \"\"\"\n    self.original_column_name = data.name\n    if missingness_column is not None:\n        self._missingness_column_name = missingness_column.name\n        full_index = data.index\n        data = data[missingness_column == 0]\n    index = data.index\n    data = np.array(data.values.reshape(-1, 1), dtype=data.dtype.name.lower())\n\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n        self._transformer.fit(data)\n\n    self.means = self._transformer.means_.reshape(-1)\n    self.stds = np.sqrt(self._transformer.covariances_).reshape(-1)\n\n    components = np.argmax(self._transformer.predict_proba(data), axis=1)\n    normalised_values = (data - self.means.reshape(1, -1)) / (self._std_multiplier * self.stds.reshape(1, -1))\n    normalised = normalised_values[np.arange(len(data)), components]\n    normalised = np.clip(normalised, -1.0, 1.0)\n    components = np.eye(self._n_components, dtype=int)[components]\n\n    transformed_data = pd.DataFrame(\n        np.hstack([normalised.reshape(-1, 1), components]),\n        index=index,\n        columns=[f\"{self.original_column_name}_normalised\"]\n        + [f\"{self.original_column_name}_c{i + 1}\" for i in range(self._n_components)],\n    )\n\n    # EXPERIMENTAL feature, removing components from the column matrix that have no data assigned to them\n    if self.remove_unused_components:\n        nunique = transformed_data.iloc[:, 1:].nunique(dropna=False)\n        unused_components = nunique[nunique == 1].index\n        unused_component_idx = [transformed_data.columns.get_loc(col_name) - 1 for col_name in unused_components]\n        self.means = np.delete(self.means, unused_component_idx)\n        self.stds = np.delete(self.stds, unused_component_idx)\n        transformed_data.drop(unused_components, axis=1, inplace=True)\n\n    if missingness_column is not None:\n        transformed_data = pd.concat([transformed_data.reindex(full_index).fillna(0.0), missingness_column], axis=1)\n\n    self.new_column_names = transformed_data.columns\n    return transformed_data.astype(\n        {col_name: int for col_name in transformed_data.columns if re.search(r\"_c\\d+\", col_name)}\n    )\n
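A sketch of the missingness path described above (the indicator column is constructed by hand here purely for illustration; in the pipeline it is produced by the AugmentMissingnessStrategy):

import numpy as np
import pandas as pd

from nhssynth.modules.dataloader.transformers.continuous import ClusterContinuousTransformer

score = pd.Series([1.2, np.nan, 3.4, 2.2, np.nan] * 20, name="score")
score_missing = score.isna().astype(int).rename("score_missing")

transformer = ClusterContinuousTransformer(n_components=2)
encoded = transformer.apply(score, missingness_column=score_missing)
# Rows flagged as missing get 0.0 in every generated column, and "score_missing" is appended
# so that revert() can reintroduce the NaNs
restored = transformer.revert(encoded.copy())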
"},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer.revert","title":"revert(data)","text":"

Revert data to pre-transformer state via the means and stds of the BGM. Extract the relevant columns from the data via the new_column_names attribute. If missingness_column was provided to the apply method, drop the missing values from the data before reverting and use the full_index to reintroduce missing values when original_column_name is constructed.

Parameters:

Name Type Description Default data DataFrame

The full dataset including the column(s) to be reverted to their pre-transformer state.

required

Returns:

Type Description DataFrame

The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Revert data to pre-transformer state via the means and stds of the BGM. Extract the relevant columns from the data via the `new_column_names` attribute.\n    If `missingness_column` was provided to the `apply` method, drop the missing values from the data before reverting and use the `full_index` to\n    reintroduce missing values when `original_column_name` is constructed.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.\n    \"\"\"\n    working_data = data[self.new_column_names]\n    full_index = working_data.index\n    if self._missingness_column_name is not None:\n        working_data = working_data[working_data[self._missingness_column_name] == 0]\n        working_data = working_data.drop(self._missingness_column_name, axis=1)\n    index = working_data.index\n\n    components = np.argmax(working_data.filter(regex=r\".*_c\\d+\").values, axis=1)\n    working_data = working_data.filter(like=\"_normalised\").values.reshape(-1)\n    if self.clip_output:\n        working_data = np.clip(working_data, -1.0, 1.0)\n\n    mean_t = self.means[components]\n    std_t = self.stds[components]\n    data[self.original_column_name] = pd.Series(\n        working_data * self._std_multiplier * std_t + mean_t, index=index, name=self.original_column_name\n    ).reindex(full_index)\n    data.drop(self.new_column_names, axis=1, inplace=True)\n    return data\n
"},{"location":"reference/modules/dataloader/transformers/datetime/","title":"datetime","text":""},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer","title":"DatetimeTransformer","text":"

Bases: TransformerWrapper

A transformer to convert datetime features to numeric features before applying an underlying (wrapped) transformer. The datetime features are converted to nanoseconds since the epoch, and missing values are assigned to 0.0 under the AugmentMissingnessStrategy.

Parameters:

Name Type Description Default transformer ColumnTransformer

The ColumnTransformer to wrap.

required

After applying the transformer, the following attributes will be populated:

Attributes:

Name Type Description original_column_name

The name of the original column.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
class DatetimeTransformer(TransformerWrapper):\n    \"\"\"\n    A transformer to convert datetime features to numeric features. Before applying an underlying (wrapped) transformer.\n    The datetime features are converted to nanoseconds since the epoch, and missing values are assigned to 0.0 under the `AugmentMissingnessStrategy`.\n\n    Args:\n        transformer: The [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer] to wrap.\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        original_column_name: The name of the original column.\n    \"\"\"\n\n    def __init__(self, transformer: ColumnTransformer) -> None:\n        super().__init__(transformer)\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None, **kwargs) -> pd.DataFrame:\n        \"\"\"\n        Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch.\n        The float value of `pd.NaT` under the operation above is then replaced with `np.nan` to ensure missing values are represented correctly.\n        Finally, the wrapped transformer is applied to the data.\n\n        Args:\n            data: The column of data to transform.\n            missingness_column: The column of missingness indicators to augment the data with.\n\n        Returns:\n            The transformed data.\n        \"\"\"\n        self.original_column_name = data.name\n        floored_data = pd.Series(data.dt.floor(\"ns\").to_numpy().astype(float), name=data.name)\n        nan_corrected_data = floored_data.replace(pd.to_datetime(pd.NaT).to_numpy().astype(float), np.nan)\n        return super().apply(nan_corrected_data, missingness_column, **kwargs)\n\n    def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:\n        \"\"\"\n        The wrapped transformer's `revert` method is applied to the data. The data is then converted back to datetime format.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The reverted data.\n        \"\"\"\n        reverted_data = super().revert(data, **kwargs)\n        data[self.original_column_name] = pd.to_datetime(\n            reverted_data[self.original_column_name].astype(\"Int64\"), unit=\"ns\"\n        )\n        return data\n
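A minimal sketch of wrapping a continuous transformer (assumed import paths; the choice of ClusterContinuousTransformer as the wrapped transformer is illustrative, not prescribed by the source):

import pandas as pd

from nhssynth.modules.dataloader.transformers.continuous import ClusterContinuousTransformer
from nhssynth.modules.dataloader.transformers.datetime import DatetimeTransformer

admissions = pd.Series(pd.date_range("2021-01-01", periods=100, freq="D"), name="admission_date")

# The wrapper converts the datetimes to float nanoseconds since the epoch and then
# delegates to the wrapped transformer
transformer = DatetimeTransformer(ClusterContinuousTransformer(n_components=3))
encoded = transformer.apply(admissions)
# revert() applies the wrapped transformer's revert and converts the result back to datetimes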
"},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer.apply","title":"apply(data, missingness_column=None, **kwargs)","text":"

Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch. The float value of pd.NaT under the operation above is then replaced with np.nan to ensure missing values are represented correctly. Finally, the wrapped transformer is applied to the data.

Parameters:

Name Type Description Default data Series

The column of data to transform.

required missingness_column Optional[Series]

The column of missingness indicators to augment the data with.

None

Returns:

Type Description DataFrame

The transformed data.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None, **kwargs) -> pd.DataFrame:\n    \"\"\"\n    Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch.\n    The float value of `pd.NaT` under the operation above is then replaced with `np.nan` to ensure missing values are represented correctly.\n    Finally, the wrapped transformer is applied to the data.\n\n    Args:\n        data: The column of data to transform.\n        missingness_column: The column of missingness indicators to augment the data with.\n\n    Returns:\n        The transformed data.\n    \"\"\"\n    self.original_column_name = data.name\n    floored_data = pd.Series(data.dt.floor(\"ns\").to_numpy().astype(float), name=data.name)\n    nan_corrected_data = floored_data.replace(pd.to_datetime(pd.NaT).to_numpy().astype(float), np.nan)\n    return super().apply(nan_corrected_data, missingness_column, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer.revert","title":"revert(data, **kwargs)","text":"

The wrapped transformer's revert method is applied to the data. The data is then converted back to datetime format.

Parameters:

Name Type Description Default data DataFrame

The full dataset including the column(s) to be reverted to their pre-transformer state.

required

Returns:

Type Description DataFrame

The reverted data.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:\n    \"\"\"\n    The wrapped transformer's `revert` method is applied to the data. The data is then converted back to datetime format.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The reverted data.\n    \"\"\"\n    reverted_data = super().revert(data, **kwargs)\n    data[self.original_column_name] = pd.to_datetime(\n        reverted_data[self.original_column_name].astype(\"Int64\"), unit=\"ns\"\n    )\n    return data\n
"},{"location":"reference/modules/evaluation/","title":"evaluation","text":""},{"location":"reference/modules/evaluation/aequitas/","title":"aequitas","text":""},{"location":"reference/modules/evaluation/io/","title":"io","text":""},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata, dir_experiment)","text":"

Sets up and validates the input paths for the files required by the evaluation module.

Parameters:

Name Type Description Default fn_dataset str

The base name of the dataset.

required fn_typed str

The name of the typed real dataset file.

required fn_synthetic_datasets str

The filename of the collection of synthetic datasets.

required fn_sdv_metadata str

The name of the SDV metadata file.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, str]

The dataset name and the paths to the typed data, synthetic datasets and SDV metadata files.

Source code in src/nhssynth/modules/evaluation/io.py
def check_input_paths(\n    fn_dataset: str, fn_typed: str, fn_synthetic_datasets: str, fn_sdv_metadata: str, dir_experiment: Path\n) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_typed: The name of the typed real dataset file.\n        fn_synthetic_datasets: The filename of the collection of synethtic datasets.\n        fn_sdv_metadata: The name of the SDV metadata file.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_synthetic_datasets, fn_sdv_metadata = io.consistent_endings(\n        [fn_typed, fn_synthetic_datasets, fn_sdv_metadata]\n    )\n    fn_typed, fn_synthetic_datasets, fn_sdv_metadata = io.potential_suffixes(\n        [fn_typed, fn_synthetic_datasets, fn_sdv_metadata], fn_dataset\n    )\n    io.warn_if_path_supplied([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)\n    io.check_exists([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)\n    return fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata\n
"},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args or from disk when the dataloader has not been run previously.

Parameters:

Name Type Description Default args Namespace

The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, DataFrame, DataFrame, dict[str, dict[str, Any]]]

The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata.

Source code in src/nhssynth/modules/evaluation/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, pd.DataFrame, dict[str, dict[str, Any]]]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"typed\", \"synthetic_datasets\", \"sdv_metadata\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"typed\"],\n            args.module_handover[\"synthetic_datasets\"],\n            args.module_handover[\"sdv_metadata\"],\n        )\n    else:\n        fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata = check_input_paths(\n            args.dataset, args.typed, args.synthetic_datasets, args.sdv_metadata, dir_experiment\n        )\n        with open(dir_experiment / fn_typed, \"rb\") as f:\n            real_data = pickle.load(f).contents\n        with open(dir_experiment / fn_sdv_metadata, \"rb\") as f:\n            sdv_metadata = pickle.load(f)\n        with open(dir_experiment / fn_synthetic_datasets, \"rb\") as f:\n            synthetic_datasets = pickle.load(f).contents\n\n        return fn_dataset, real_data, synthetic_datasets, sdv_metadata\n
"},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.output_eval","title":"output_eval(evaluations, fn_dataset, fn_evaluations, dir_experiment)","text":"

Writes the collection of evaluations to the experiment directory as a pickled bundle.

Parameters:

Name Type Description Default evaluations DataFrame

The evaluations to output.

required fn_dataset Path

The base name of the dataset.

required fn_evaluations str

The filename of the collection of evaluations.

required dir_experiment Path

The path to the experiment output directory.

required

Returns:

Type Description None

Nothing is returned; the evaluations bundle is written to the experiment directory.

Source code in src/nhssynth/modules/evaluation/io.py
def output_eval(\n    evaluations: pd.DataFrame,\n    fn_dataset: Path,\n    fn_evaluations: str,\n    dir_experiment: Path,\n) -> None:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        evaluations: The evaluations to output.\n        fn_dataset: The base name of the dataset.\n        fn_evaluations: The filename of the collection of evaluations.\n        dir_experiment: The path to the experiment output directory.\n\n    Returns:\n        The path to output the model.\n    \"\"\"\n    fn_evaluations = io.consistent_ending(fn_evaluations)\n    fn_evaluations = io.potential_suffix(fn_evaluations, fn_dataset)\n    io.warn_if_path_supplied([fn_evaluations], dir_experiment)\n    with open(dir_experiment / fn_evaluations, \"wb\") as f:\n        pickle.dump(Evaluations(evaluations), f)\n
"},{"location":"reference/modules/evaluation/metrics/","title":"metrics","text":""},{"location":"reference/modules/evaluation/run/","title":"run","text":""},{"location":"reference/modules/evaluation/tasks/","title":"tasks","text":""},{"location":"reference/modules/evaluation/tasks/#nhssynth.modules.evaluation.tasks.Task","title":"Task","text":"

A task offers a light-touch way for users to specify any downstream task that they want to run on a dataset.

Parameters:

Name Type Description Default name str

The name of the task.

required run Callable

The function to run.

required supports_aequitas

Whether the task supports Aequitas evaluation.

False description str

The description of the task.

'' Source code in src/nhssynth/modules/evaluation/tasks.py
class Task:\n    \"\"\"\n    A task offers a light-touch way for users to specify any arbitrary downstream task that they want to run on a dataset.\n\n    Args:\n        name: The name of the task.\n        run: The function to run.\n        supports_aequitas: Whether the task supports Aequitas evaluation.\n        description: The description of the task.\n    \"\"\"\n\n    def __init__(self, name: str, run: Callable, supports_aequitas=False, description: str = \"\"):\n        self._name: str = name\n        self._run: Callable = run\n        self._supports_aequitas: bool = supports_aequitas\n        self._description: str = description\n\n    def __str__(self) -> str:\n        return f\"{self.name}: {self.description}\" if self.description else self.name\n\n    def __repr__(self) -> str:\n        return str([self.name, self.run, self.supports_aequitas, self.description])\n\n    def run(self, *args, **kwargs):\n        return self._run(*args, **kwargs)\n
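A sketch of defining a task (the dataset columns, model and metric are hypothetical; the only contract taken from the source is that the callable returns a prediction column and a dict of metric values, which is how EvalFrame consumes tasks):

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

from nhssynth.modules.evaluation.tasks import Task

def run_readmission(dataset: pd.DataFrame):
    # Hypothetical binary target column "readmitted"
    X, y = dataset.drop(columns=["readmitted"]), dataset["readmitted"]
    model = LogisticRegression(max_iter=1000).fit(X, y)
    preds = pd.Series(model.predict(X), index=dataset.index, name="readmitted_pred")
    return preds, {"readmission_auc": roc_auc_score(y, model.predict_proba(X)[:, 1])}

task = Task(
    "readmission",
    run_readmission,
    supports_aequitas=True,
    description="Fit a simple readmission classifier and report its AUC",
)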
"},{"location":"reference/modules/evaluation/tasks/#nhssynth.modules.evaluation.tasks.get_tasks","title":"get_tasks(fn_dataset, tasks_root)","text":"

Searches for and imports all tasks in the tasks directory for a given dataset. Uses importlib to extract the task from the file.

Parameters:

Name Type Description Default fn_dataset str

The name of the dataset.

required tasks_root str

The root directory for downstream tasks.

required

Returns:

Type Description list[Task]

A list of tasks.

Source code in src/nhssynth/modules/evaluation/tasks.py
def get_tasks(\n    fn_dataset: str,\n    tasks_root: str,\n) -> list[Task]:\n    \"\"\"\n    Searches for and imports all tasks in the tasks directory for a given dataset.\n    Uses `importlib` to extract the task from the file.\n\n    Args:\n        fn_dataset: The name of the dataset.\n        tasks_root: The root directory for downstream tasks.\n\n    Returns:\n        A list of tasks.\n    \"\"\"\n    tasks_dir = Path(tasks_root) / fn_dataset\n    assert (\n        tasks_dir.exists()\n    ), f\"Downstream tasks directory does not exist ({tasks_dir}), NB there should be a directory in TASKS_DIR with the same name as the dataset.\"\n    tasks = []\n    for task_path in tasks_dir.iterdir():\n        if task_path.name.startswith((\".\", \"__\")):\n            continue\n        assert task_path.suffix == \".py\", f\"Downstream task file must be a python file ({task_path.name})\"\n        spec = importlib.util.spec_from_file_location(\n            \"nhssynth_task_\" + task_path.name, os.getcwd() + \"/\" + str(task_path)\n        )\n        task_module = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(task_module)\n        tasks.append(task_module.task)\n    return tasks\n
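A sketch of the on-disk layout this discovery mechanism assumes (directory and file names are illustrative):

# <tasks_root>/<dataset name>/readmission.py must define a module-level `task` object
# (see Task above); files starting with "." or "__" are skipped, everything else must be .py
from nhssynth.modules.evaluation.tasks import get_tasks

tasks = get_tasks("my_dataset", "tasks")  # returns the `task` object from each discovered file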
"},{"location":"reference/modules/evaluation/utils/","title":"utils","text":""},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame","title":"EvalFrame","text":"

Data structure for specifying and recording the evaluations of a set of synthetic datasets against a real dataset. All of the choices made by the user in the evaluation module are consolidated into this class.

After running evaluate on a set of synthetic datasets, the evaluations can be retrieved using get_evaluations. They are stored in a dict of dataframes with indices matching that of the supplied dataframe of synthetic datasets.

Parameters:

Name Type Description Default tasks list[Task]

A list of downstream tasks to run on the experiments.

required metrics list[str]

A list of metrics to calculate on the experiments.

required sdv_metadata dict[str, dict[str, str]]

The SDV metadata for the dataset.

required aequitas bool

Whether to run Aequitas on the results of supported downstream tasks.

False aequitas_attributes list[str]

The fairness-related attributes to use for Aequitas analysis.

[] key_numerical_fields list[str]

The numerical fields to use for SDV privacy metrics.

[] sensitive_numerical_fields list[str]

The numerical fields to use for SDV privacy metrics.

[] key_categorical_fields list[str]

The categorical fields to use for SDV privacy metrics.

[] sensitive_categorical_fields list[str]

The categorical fields to use for SDV privacy metrics.

[] Source code in src/nhssynth/modules/evaluation/utils.py
class EvalFrame:\n    \"\"\"\n    Data structure for specifying and recording the evaluations of a set of synthetic datasets against a real dataset.\n    All of the choices made by the user in the evaluation module are consolidated into this class.\n\n    After running `evaluate` on a set of synthetic datasets, the evaluations can be retrieved using `get_evaluations`.\n    They are stored in a dict of dataframes with indices matching that of the supplied dataframe of synthetic datasets.\n\n    Args:\n        tasks: A list of downstream tasks to run on the experiments.\n        metrics: A list of metrics to calculate on the experiments.\n        sdv_metadata: The SDV metadata for the dataset.\n        aequitas: Whether to run Aequitas on the results of supported downstream tasks.\n        aequitas_attributes: The fairness-related attributes to use for Aequitas analysis.\n        key_numerical_fields: The numerical fields to use for SDV privacy metrics.\n        sensitive_numerical_fields: The numerical fields to use for SDV privacy metrics.\n        key_categorical_fields: The categorical fields to use for SDV privacy metrics.\n        sensitive_categorical_fields: The categorical fields to use for SDV privacy metrics.\n    \"\"\"\n\n    def __init__(\n        self,\n        tasks: list[Task],\n        metrics: list[str],\n        sdv_metadata: dict[str, dict[str, str]],\n        aequitas: bool = False,\n        aequitas_attributes: list[str] = [],\n        key_numerical_fields: list[str] = [],\n        sensitive_numerical_fields: list[str] = [],\n        key_categorical_fields: list[str] = [],\n        sensitive_categorical_fields: list[str] = [],\n    ):\n        self._tasks = tasks\n        self._aequitas = aequitas\n        self._aequitas_attributes = aequitas_attributes\n\n        self._metrics = metrics\n        self._sdv_metadata = sdv_metadata\n\n        self._key_numerical_fields = key_numerical_fields\n        self._sensitive_numerical_fields = sensitive_numerical_fields\n        self._key_categorical_fields = key_categorical_fields\n        self._sensitive_categorical_fields = sensitive_categorical_fields\n        assert all([metric not in NUMERICAL_PRIVACY_METRICS for metric in self._metrics]) or (\n            self._key_numerical_fields and self._sensitive_numerical_fields\n        ), \"Numerical key and sensitive fields must be provided when an SDV privacy metric is used.\"\n        assert all([metric not in CATEGORICAL_PRIVACY_METRICS for metric in self._metrics]) or (\n            self._key_categorical_fields and self._sensitive_categorical_fields\n        ), \"Categorical key and sensitive fields must be provided when an SDV privacy metric is used.\"\n\n        self._metric_groups = self._build_metric_groups()\n\n    def _build_metric_groups(self) -> list[str]:\n        \"\"\"\n        Iterate through the concatenated list of metrics provided by the user and refer to the\n        [defined metric groups][nhssynth.common.constants] to identify which to evaluate.\n\n        Returns:\n            A list of metric groups to evaluate.\n        \"\"\"\n        metric_groups = set()\n        if self._tasks:\n            metric_groups.add(\"task\")\n        if self._aequitas:\n            metric_groups.add(\"aequitas\")\n        for metric in self._metrics:\n            if metric in TABLE_METRICS:\n                metric_groups.add(\"table\")\n            if metric in NUMERICAL_PRIVACY_METRICS or metric in CATEGORICAL_PRIVACY_METRICS:\n                
metric_groups.add(\"privacy\")\n            if metric in TABLE_METRICS and issubclass(TABLE_METRICS[metric], MultiSingleColumnMetric):\n                metric_groups.add(\"columnwise\")\n            if metric in TABLE_METRICS and issubclass(TABLE_METRICS[metric], MultiColumnPairsMetric):\n                metric_groups.add(\"pairwise\")\n        return list(metric_groups)\n\n    def evaluate(self, real_dataset: pd.DataFrame, synthetic_datasets: list[dict[str, Any]]) -> None:\n        \"\"\"\n        Evaluate a set of synthetic datasets against a real dataset.\n\n        Args:\n            real_dataset: The real dataset to evaluate against.\n            synthetic_datasets: The synthetic datasets to evaluate.\n        \"\"\"\n        assert not any(\"Real\" in i for i in synthetic_datasets.index), \"Real is a reserved dataset ID.\"\n        assert synthetic_datasets.index.is_unique, \"Dataset IDs must be unique.\"\n        self._evaluations = pd.DataFrame(index=synthetic_datasets.index, columns=self._metric_groups)\n        self._evaluations.loc[(\"Real\", None, None)] = self._step(real_dataset)\n        pbar = tqdm(synthetic_datasets.iterrows(), desc=\"Evaluating\", total=len(synthetic_datasets))\n        for i, dataset in pbar:\n            pbar.set_description(f\"Evaluating {i[0]}, repeat {i[1]}, config {i[2]}\")\n            self._evaluations.loc[i] = self._step(real_dataset, dataset.values[0])\n\n    def get_evaluations(self) -> dict[str, pd.DataFrame]:\n        \"\"\"\n        Unpack the `self._evaluations` dataframe, where each metric group is a column, into a dict of dataframes.\n\n        Returns:\n            A dict of dataframes, one for each metric group, containing the evaluations.\n        \"\"\"\n        assert hasattr(\n            self, \"_evaluations\"\n        ), \"You must first run `evaluate` on a `real_dataset` and set of `synthetic_datasets`.\"\n        return {\n            metric_group: pd.DataFrame(\n                self._evaluations[metric_group].values.tolist(), index=self._evaluations.index\n            ).dropna(how=\"all\")\n            for metric_group in self._metric_groups\n        }\n\n    def _task_step(self, data: pd.DataFrame) -> dict[str, dict]:\n        \"\"\"\n        Run the downstream tasks on the dataset. 
Optionally run Aequitas on the results of the tasks.\n\n        Args:\n            data: The dataset to run the tasks on.\n\n        Returns:\n            A dict of dicts, one for each metric group, to be populated with each groups metric values.\n        \"\"\"\n        metric_dict = {metric_group: {} for metric_group in self._metric_groups}\n        for task in tqdm(self._tasks, desc=\"Running downstream tasks\", leave=False):\n            task_pred_column, task_metric_values = task.run(data)\n            metric_dict[\"task\"].update(task_metric_values)\n            if self._aequitas and task.supports_aequitas:\n                metric_dict[\"aequitas\"].update(run_aequitas(data[self._aequitas_attributes].join(task_pred_column)))\n        return metric_dict\n\n    def _compute_metric(\n        self, metric_dict: dict, metric: str, real_data: pd.DataFrame, synthetic_data: pd.DataFrame\n    ) -> dict[str, dict]:\n        \"\"\"\n        Given a metric, determine the correct way to evaluate it via the lists defined in `nhssynth.common.constants`.\n\n        Args:\n            metric_dict: The dict of dicts to populate with metric values.\n            metric: The metric to evaluate.\n            real_data: The real dataset to evaluate against.\n            synthetic_data: The synthetic dataset to evaluate.\n\n        Returns:\n            The metric_dict updated with the value of the metric.\n        \"\"\"\n        with pd.option_context(\"mode.chained_assignment\", None), warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"ConvergenceWarning\")\n            if metric in TABLE_METRICS:\n                metric_dict[\"table\"][metric] = TABLE_METRICS[metric].compute(\n                    real_data, synthetic_data, self._sdv_metadata\n                )\n                if issubclass(TABLE_METRICS[metric], MultiSingleColumnMetric):\n                    metric_dict[\"columnwise\"][metric] = TABLE_METRICS[metric].compute_breakdown(\n                        real_data, synthetic_data, self._sdv_metadata\n                    )\n                elif issubclass(TABLE_METRICS[metric], MultiColumnPairsMetric):\n                    metric_dict[\"pairwise\"][metric] = TABLE_METRICS[metric].compute_breakdown(\n                        real_data, synthetic_data, self._sdv_metadata\n                    )\n            elif metric in NUMERICAL_PRIVACY_METRICS:\n                metric_dict[\"privacy\"][metric] = NUMERICAL_PRIVACY_METRICS[metric].compute(\n                    real_data.dropna(),\n                    synthetic_data.dropna(),\n                    self._sdv_metadata,\n                    self._key_numerical_fields,\n                    self._sensitive_numerical_fields,\n                )\n            elif metric in CATEGORICAL_PRIVACY_METRICS:\n                metric_dict[\"privacy\"][metric] = CATEGORICAL_PRIVACY_METRICS[metric].compute(\n                    real_data.dropna(),\n                    synthetic_data.dropna(),\n                    self._sdv_metadata,\n                    self._key_categorical_fields,\n                    self._sensitive_categorical_fields,\n                )\n        return metric_dict\n\n    def _step(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame = None) -> dict[str, dict]:\n        \"\"\"\n        Run the two functions above (or only the tasks when no synthetic data is provided).\n\n        Args:\n            real_data: The real dataset to evaluate against.\n            synthetic_data: The synthetic dataset to 
evaluate.\n\n        Returns:\n            A dict of dicts, one for each metric grou, to populate a row of `self._evaluations` corresponding to the `synthetic_data`.\n        \"\"\"\n        if synthetic_data is None:\n            metric_dict = self._task_step(real_data)\n        else:\n            metric_dict = self._task_step(synthetic_data)\n            for metric in tqdm(self._metrics, desc=\"Running metrics\", leave=False):\n                metric_dict = self._compute_metric(metric_dict, metric, real_data, synthetic_data)\n        return metric_dict\n
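A high-level sketch of the intended flow (every name below stands in for an output of another module and is an assumption, and the metric key is hypothetical; valid choices come from the module's METRIC_CHOICES mapping):

from nhssynth.modules.evaluation.utils import EvalFrame

eval_frame = EvalFrame(
    tasks=tasks,                # e.g. the output of get_tasks(...)
    metrics=["KSComplement"],   # hypothetical metric key
    sdv_metadata=sdv_metadata,  # produced by the dataloader module
)
# `synthetic_datasets` is the model module's bundle, indexed by (architecture, repeat, config)
eval_frame.evaluate(real_data, synthetic_datasets)
evaluations = eval_frame.get_evaluations()  # dict of dataframes keyed by metric group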
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame.evaluate","title":"evaluate(real_dataset, synthetic_datasets)","text":"

Evaluate a set of synthetic datasets against a real dataset.

Parameters:

Name Type Description Default real_dataset DataFrame

The real dataset to evaluate against.

required synthetic_datasets list[dict[str, Any]]

The synthetic datasets to evaluate.

required Source code in src/nhssynth/modules/evaluation/utils.py
def evaluate(self, real_dataset: pd.DataFrame, synthetic_datasets: list[dict[str, Any]]) -> None:\n    \"\"\"\n    Evaluate a set of synthetic datasets against a real dataset.\n\n    Args:\n        real_dataset: The real dataset to evaluate against.\n        synthetic_datasets: The synthetic datasets to evaluate.\n    \"\"\"\n    assert not any(\"Real\" in i for i in synthetic_datasets.index), \"Real is a reserved dataset ID.\"\n    assert synthetic_datasets.index.is_unique, \"Dataset IDs must be unique.\"\n    self._evaluations = pd.DataFrame(index=synthetic_datasets.index, columns=self._metric_groups)\n    self._evaluations.loc[(\"Real\", None, None)] = self._step(real_dataset)\n    pbar = tqdm(synthetic_datasets.iterrows(), desc=\"Evaluating\", total=len(synthetic_datasets))\n    for i, dataset in pbar:\n        pbar.set_description(f\"Evaluating {i[0]}, repeat {i[1]}, config {i[2]}\")\n        self._evaluations.loc[i] = self._step(real_dataset, dataset.values[0])\n
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame.get_evaluations","title":"get_evaluations()","text":"

Unpack the self._evaluations dataframe, where each metric group is a column, into a dict of dataframes.

Returns:

Type Description dict[str, DataFrame]

A dict of dataframes, one for each metric group, containing the evaluations.

Source code in src/nhssynth/modules/evaluation/utils.py
def get_evaluations(self) -> dict[str, pd.DataFrame]:\n    \"\"\"\n    Unpack the `self._evaluations` dataframe, where each metric group is a column, into a dict of dataframes.\n\n    Returns:\n        A dict of dataframes, one for each metric group, containing the evaluations.\n    \"\"\"\n    assert hasattr(\n        self, \"_evaluations\"\n    ), \"You must first run `evaluate` on a `real_dataset` and set of `synthetic_datasets`.\"\n    return {\n        metric_group: pd.DataFrame(\n            self._evaluations[metric_group].values.tolist(), index=self._evaluations.index\n        ).dropna(how=\"all\")\n        for metric_group in self._metric_groups\n    }\n
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.validate_metric_args","title":"validate_metric_args(args, fn_dataset, columns)","text":"

Validate the arguments for downstream tasks and Aequitas.

Parameters:

Name Type Description Default args Namespace

The argument namespace to validate.

required fn_dataset str

The name of the dataset.

required columns Index

The columns in the dataset.

required

Returns:

Type Description tuple[list[Task], Namespace]

The validated arguments, the list of tasks and the list of metrics.

Source code in src/nhssynth/modules/evaluation/utils.py
def validate_metric_args(\n    args: argparse.Namespace, fn_dataset: str, columns: pd.Index\n) -> tuple[list[Task], argparse.Namespace]:\n    \"\"\"\n    Validate the arguments for downstream tasks and Aequitas.\n\n    Args:\n        args: The argument namespace to validate.\n        fn_dataset: The name of the dataset.\n        columns: The columns in the dataset.\n\n    Returns:\n        The validated arguments, the list of tasks and the list of metrics.\n    \"\"\"\n    if args.downstream_tasks:\n        tasks = get_tasks(fn_dataset, args.tasks_dir)\n        if not tasks:\n            warnings.warn(\"No valid downstream tasks found.\")\n    else:\n        tasks = []\n    if args.aequitas:\n        if not args.downstream_tasks or not any([task.supports_aequitas for task in tasks]):\n            warnings.warn(\n                \"Aequitas can only work in context of downstream tasks involving binary classification problems.\"\n            )\n        if not args.aequitas_attributes:\n            warnings.warn(\"No attributes specified for Aequitas analysis, defaulting to all columns in the dataset.\")\n            args.aequitas_attributes = columns.tolist()\n        assert all(\n            [attr in columns for attr in args.aequitas_attributes]\n        ), \"Invalid attribute(s) specified for Aequitas analysis.\"\n    metrics = {}\n    for metric_group in METRIC_CHOICES:\n        selected_metrics = getattr(args, \"_\".join(metric_group.split()).lower() + \"_metrics\") or []\n        metrics.update({metric_name: METRIC_CHOICES[metric_group][metric_name] for metric_name in selected_metrics})\n    return args, tasks, metrics\n
"},{"location":"reference/modules/model/","title":"model","text":""},{"location":"reference/modules/model/io/","title":"io","text":""},{"location":"reference/modules/model/io/#nhssynth.modules.model.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_transformed, fn_metatransformer, dir_experiment)","text":"

Sets up and validates the input paths for the files required by the model module.

Parameters:

Name Type Description Default fn_dataset str

The base name of the dataset.

required fn_transformed str

The name of the transformed data file.

required fn_metatransformer str

The name of the metatransformer file.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, str]

The paths to the data, metadata and metatransformer files.

Source code in src/nhssynth/modules/model/io.py
def check_input_paths(\n    fn_dataset: str, fn_transformed: str, fn_metatransformer: str, dir_experiment: Path\n) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_transformed: The name of the transformed data file.\n        fn_metatransformer: The name of the metatransformer file.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_transformed, fn_metatransformer = io.consistent_endings([fn_transformed, fn_metatransformer])\n    fn_transformed, fn_metatransformer = io.potential_suffixes([fn_transformed, fn_metatransformer], fn_dataset)\n    io.warn_if_path_supplied([fn_transformed, fn_metatransformer], dir_experiment)\n    io.check_exists([fn_transformed, fn_metatransformer], dir_experiment)\n    return fn_dataset, fn_transformed, fn_metatransformer\n
"},{"location":"reference/modules/model/io/#nhssynth.modules.model.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args or from disk when the dataloader has not been run previously.

Parameters:

Name Type Description Default args Namespace

The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, DataFrame, dict[str, int], MetaTransformer]

The data, metadata and metatransformer.

Source code in src/nhssynth/modules/model/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, dict[str, int], MetaTransformer]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The data, metadata and metatransformer.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"transformed\", \"metatransformer\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"transformed\"],\n            args.module_handover[\"metatransformer\"],\n        )\n    else:\n        fn_dataset, fn_transformed, fn_metatransformer = check_input_paths(\n            args.dataset, args.transformed, args.metatransformer, dir_experiment\n        )\n\n        with open(dir_experiment / fn_transformed, \"rb\") as f:\n            data = pickle.load(f)\n        with open(dir_experiment / fn_metatransformer, \"rb\") as f:\n            mt = pickle.load(f)\n\n        return fn_dataset, data, mt\n
"},{"location":"reference/modules/model/run/","title":"run","text":""},{"location":"reference/modules/model/utils/","title":"utils","text":""},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.configs_from_arg_combinations","title":"configs_from_arg_combinations(args, arg_list)","text":"

Generates a list of configurations from a list of arguments. Each configuration is one element of the Cartesian product of the argument values provided and identified in arg_list.

Parameters:

Name Type Description Default args Namespace

The arguments.

required arg_list list[str]

The list of arguments to generate configurations from.

required

Returns:

Type Description list[dict[str, Any]]

A list of configurations.

Source code in src/nhssynth/modules/model/utils.py
def configs_from_arg_combinations(args: argparse.Namespace, arg_list: list[str]) -> list[dict[str, Any]]:\n    \"\"\"\n    Generates a list of configurations from a list of arguments. Each configuration is one of a cartesian product of\n    the arguments provided and identified in `arg_list`.\n\n    Args:\n        args: The arguments.\n        arg_list: The list of arguments to generate configurations from.\n\n    Returns:\n        A list of configurations.\n    \"\"\"\n    wrapped_args = {arg: wrap_arg(getattr(args, arg)) for arg in arg_list}\n    combinations = list(itertools.product(*wrapped_args.values()))\n    return [{k: v for k, v in zip(wrapped_args.keys(), values) if v is not None} for values in combinations]\n
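A worked example of the Cartesian-product behaviour (import path assumed from the source file above):

import argparse

from nhssynth.modules.model.utils import configs_from_arg_combinations

args = argparse.Namespace(num_epochs=[50, 100], patience=5)
configs_from_arg_combinations(args, ["num_epochs", "patience"])
# -> [{"num_epochs": 50, "patience": 5}, {"num_epochs": 100, "patience": 5}]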
"},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.get_experiments","title":"get_experiments(args)","text":"

Generates a dataframe of experiments from the arguments provided.

Parameters:

Name Type Description Default args Namespace

The arguments.

required

Returns:

Type Description DataFrame

A dataframe of experiments indexed by architecture, repeat and config ID.

Source code in src/nhssynth/modules/model/utils.py
def get_experiments(args: argparse.Namespace) -> pd.DataFrame:\n    \"\"\"\n    Generates a dataframe of experiments from the arguments provided.\n\n    Args:\n        args: The arguments.\n\n    Returns:\n        A dataframe of experiments indexed by architecture, repeat and config ID.\n    \"\"\"\n    experiments = pd.DataFrame(\n        columns=[\"architecture\", \"repeat\", \"config\", \"model_config\", \"seed\", \"train_config\", \"num_configs\"]\n    )\n    train_configs = configs_from_arg_combinations(args, [\"num_epochs\", \"patience\"])\n    for arch_name, repeat in itertools.product(*[wrap_arg(args.architecture), list(range(args.repeats))]):\n        arch = MODELS[arch_name]\n        model_configs = configs_from_arg_combinations(args, arch.get_args() + [\"batch_size\", \"use_gpu\"])\n        for i, (train_config, model_config) in enumerate(itertools.product(train_configs, model_configs)):\n            experiments.loc[len(experiments.index)] = {\n                \"architecture\": arch_name,\n                \"repeat\": repeat + 1,\n                \"config\": i + 1,\n                \"model_config\": model_config,\n                \"num_configs\": len(model_configs) * len(train_configs),\n                \"seed\": args.seed + repeat if args.seed else None,\n                \"train_config\": train_config,\n            }\n    return experiments.set_index([\"architecture\", \"repeat\", \"config\"], drop=True)\n
"},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.wrap_arg","title":"wrap_arg(arg)","text":"

Wraps a single argument in a list if it is not already a list or tuple.

Parameters:

    arg (Any): The argument to wrap. [required]

Returns:

    Union[list, tuple]: The wrapped argument.

Source code in src/nhssynth/modules/model/utils.py
def wrap_arg(arg: Any) -> Union[list, tuple]:\n    \"\"\"\n    Wraps a single argument in a list if it is not already a list or tuple.\n\n    Args:\n        arg: The argument to wrap.\n\n    Returns:\n        The wrapped argument.\n    \"\"\"\n    if not isinstance(arg, list) and not isinstance(arg, tuple):\n        return [arg]\n    return arg\n
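A couple of illustrative calls (not part of the package's test suite) showing the behaviour:

from nhssynth.modules.model.utils import wrap_arg

print(wrap_arg(0.001))          # [0.001]        - scalars are wrapped in a list
print(wrap_arg([0.001, 0.01]))  # [0.001, 0.01]  - lists pass through unchanged
print(wrap_arg((32, 64)))       # (32, 64)       - tuples also pass through unchanged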
"},{"location":"reference/modules/model/common/","title":"common","text":""},{"location":"reference/modules/model/common/dp/","title":"dp","text":""},{"location":"reference/modules/model/common/dp/#nhssynth.modules.model.common.dp.DPMixin","title":"DPMixin","text":"

Bases: ABC

Mixin class to make a Model differentially private

Parameters:

    target_epsilon (float): The target epsilon for the model during training [default: 3.0]
    target_delta (Optional[float]): The target delta for the model during training [default: None]
    max_grad_norm (float): The maximum norm for the gradients, they are trimmed to this norm if they are larger [default: 5.0]
    secure_mode (bool): Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the csprng package [default: False]

Attributes:

    target_epsilon (float): The target epsilon for the model during training
    target_delta (float): The target delta for the model during training
    max_grad_norm (float): The maximum norm for the gradients, they are trimmed to this norm if they are larger
    secure_mode (bool): Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the csprng package

Raises:

    TypeError: If the inheritor is not a Model

Source code in src/nhssynth/modules/model/common/dp.py
class DPMixin(ABC):\n    \"\"\"\n    Mixin class to make a [`Model`][nhssynth.modules.model.common.model.Model] differentially private\n\n    Args:\n        target_epsilon: The target epsilon for the model during training\n        target_delta: The target delta for the model during training\n        max_grad_norm: The maximum norm for the gradients, they are trimmed to this norm if they are larger\n        secure_mode: Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the `csprng` package\n\n    Attributes:\n        target_epsilon: The target epsilon for the model during training\n        target_delta: The target delta for the model during training\n        max_grad_norm: The maximum norm for the gradients, they are trimmed to this norm if they are larger\n        secure_mode: Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the `csprng` package\n\n    Raises:\n        TypeError: If the inheritor is not a `Model`\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        target_epsilon: float = 3.0,\n        target_delta: Optional[float] = None,\n        max_grad_norm: float = 5.0,\n        secure_mode: bool = False,\n        **kwargs,\n    ):\n        if not isinstance(self, Model):\n            raise TypeError(\"DPMixin can only be used with Model classes\")\n        super(DPMixin, self).__init__(*args, **kwargs)\n        self.target_epsilon: float = target_epsilon\n        self.target_delta: float = target_delta or 1 / self.nrows\n        self.max_grad_norm: float = max_grad_norm\n        self.secure_mode: bool = secure_mode\n\n    def make_private(self, num_epochs: int, module: Optional[nn.Module] = None) -> GradSampleModule:\n        \"\"\"\n        Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.\n\n        Args:\n            num_epochs: The number of epochs to train for, used to calculate the privacy budget.\n            module: The module to make private.\n\n        Returns:\n            The privatised module.\n        \"\"\"\n        module = module or self\n        self.privacy_engine = PrivacyEngine(secure_mode=self.secure_mode)\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n            warnings.filterwarnings(\"ignore\", message=\"Optimal order is the largest alpha\")\n            module, module.optim, self.data_loader = self.privacy_engine.make_private_with_epsilon(\n                module=module,\n                optimizer=module.optim,\n                data_loader=self.data_loader,\n                epochs=num_epochs,\n                target_epsilon=self.target_epsilon,\n                target_delta=self.target_delta,\n                max_grad_norm=self.max_grad_norm,\n            )\n        print(\n            f\"Using sigma={module.optim.noise_multiplier} and C={self.max_grad_norm} to target (\u03b5, \u03b4) = ({self.target_epsilon}, {self.target_delta})-differential privacy.\".format()\n        )\n        self.get_epsilon = self.privacy_engine.accountant.get_epsilon\n        return module\n\n    def _generate_metric_str(self, key) -> str:\n        \"\"\"Generates a string to display the current value of the metric `key`.\"\"\"\n        if key == \"Privacy\":\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n                warnings.filterwarnings(\"ignore\", 
message=\"Optimal order is the largest alpha\")\n                val = self.get_epsilon(self.target_delta)\n            self.metrics[key] = np.append(self.metrics[key], val)\n            return f\"{(key + ' \u03b5 Spent:').ljust(self.max_length)}  {val:.4f}\"\n        else:\n            return super()._generate_metric_str(key)\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\"target_epsilon\", \"target_delta\", \"max_grad_norm\", \"secure_mode\"]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\"Privacy\"]\n\n    def _start_training(self, num_epochs, patience, displayed_metrics):\n        self.make_private(num_epochs)\n        super()._start_training(num_epochs, patience, displayed_metrics)\n
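Because DPMixin calls super().__init__ and relies on attributes set up by the model (e.g. nrows and data_loader), it must appear before the concrete Model subclass in the bases, exactly as DPVAE does below. A generic, self-contained sketch of that cooperative-mixin pattern (the class names here are invented and unrelated to the package):

class Base:
    def __init__(self, *args, **kwargs):
        self.nrows = 100  # stand-in for state set up by the model

class Mixin:
    def __init__(self, *args, extra: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)  # runs Base.__init__ first via the MRO
        self.extra = extra / self.nrows    # can now rely on Base's attributes

class Combined(Mixin, Base):  # mixin first, model second
    pass

print(Combined(extra=2.0).extra)  # 0.02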
"},{"location":"reference/modules/model/common/dp/#nhssynth.modules.model.common.dp.DPMixin.make_private","title":"make_private(num_epochs, module=None)","text":"

Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.

Parameters:

    num_epochs (int): The number of epochs to train for, used to calculate the privacy budget. [required]
    module (Optional[Module]): The module to make private. [default: None]

Returns:

    GradSampleModule: The privatised module.

Source code in src/nhssynth/modules/model/common/dp.py
def make_private(self, num_epochs: int, module: Optional[nn.Module] = None) -> GradSampleModule:\n    \"\"\"\n    Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.\n\n    Args:\n        num_epochs: The number of epochs to train for, used to calculate the privacy budget.\n        module: The module to make private.\n\n    Returns:\n        The privatised module.\n    \"\"\"\n    module = module or self\n    self.privacy_engine = PrivacyEngine(secure_mode=self.secure_mode)\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n        warnings.filterwarnings(\"ignore\", message=\"Optimal order is the largest alpha\")\n        module, module.optim, self.data_loader = self.privacy_engine.make_private_with_epsilon(\n            module=module,\n            optimizer=module.optim,\n            data_loader=self.data_loader,\n            epochs=num_epochs,\n            target_epsilon=self.target_epsilon,\n            target_delta=self.target_delta,\n            max_grad_norm=self.max_grad_norm,\n        )\n    print(\n        f\"Using sigma={module.optim.noise_multiplier} and C={self.max_grad_norm} to target (\u03b5, \u03b4) = ({self.target_epsilon}, {self.target_delta})-differential privacy.\".format()\n    )\n    self.get_epsilon = self.privacy_engine.accountant.get_epsilon\n    return module\n
"},{"location":"reference/modules/model/common/mlp/","title":"mlp","text":""},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MLP","title":"MLP","text":"

Bases: Module

Fully connected or residual neural nets for classification and regression.

"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MLP--parameters","title":"Parameters","text":"

task_type: str
    classification or regression
n_units_in: int
    Number of features
n_units_out: int
    Number of outputs
n_layers_hidden: int
    Number of hidden layers
n_units_hidden: int
    Number of hidden units in each layer
nonlin: string, default 'elu'
    Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu', 'tanh' or 'leaky_relu'.
lr: float
    Learning rate for optimizer.
weight_decay: float
    l2 (ridge) penalty for the weights.
n_iter: int
    Maximum number of iterations.
batch_size: int
    Batch size
n_iter_print: int
    Number of iterations after which to print updates and check the validation loss.
random_state: int
    random_state used
patience: int
    Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
    Minimum number of iterations to go through before starting early stopping
dropout: float
    Dropout value. If 0, the dropout is not used.
clipping_value: int, default 1
    Gradients clipping value
batch_norm: bool
    Enable/disable batch norm
early_stopping: bool
    Enable/disable early stopping
residual: bool
    Add residuals.
loss: Callable
    Optional custom loss function. If None, the loss is CrossEntropy for classification tasks, or RMSE for regression.

Source code in src/nhssynth/modules/model/common/mlp.py
class MLP(nn.Module):\n    \"\"\"\n    Fully connected or residual neural nets for classification and regression.\n\n    Parameters\n    ----------\n    task_type: str\n        classification or regression\n    n_units_int: int\n        Number of features\n    n_units_out: int\n        Number of outputs\n    n_layers_hidden: int\n        Number of hidden layers\n    n_units_hidden: int\n        Number of hidden units in each layer\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu', 'tanh' or 'leaky_relu'.\n    lr: float\n        learning rate for optimizer.\n    weight_decay: float\n        l2 (ridge) penalty for the weights.\n    n_iter: int\n        Maximum number of iterations.\n    batch_size: int\n        Batch size\n    n_iter_print: int\n        Number of iterations after which to print updates and check the validation loss.\n    random_state: int\n        random_state used\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    dropout: float\n        Dropout value. If 0, the dropout is not used.\n    clipping_value: int, default 1\n        Gradients clipping value\n    batch_norm: bool\n        Enable/disable batch norm\n    early_stopping: bool\n        Enable/disable early stopping\n    residual: bool\n        Add residuals.\n    loss: Callable\n        Optional Custom loss function. If None, the loss is CrossEntropy for classification tasks, or RMSE for regression.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_units_in: int,\n        n_units_out: int,\n        n_layers_hidden: int = 1,\n        n_units_hidden: int = 100,\n        activation: str = \"relu\",\n        activation_out: Optional[list[tuple[str, int]]] = None,\n        lr: float = 1e-3,\n        weight_decay: float = 1e-3,\n        opt_betas: tuple = (0.9, 0.999),\n        n_iter: int = 1000,\n        batch_size: int = 500,\n        n_iter_print: int = 100,\n        patience: int = 10,\n        n_iter_min: int = 100,\n        dropout: float = 0.1,\n        clipping_value: int = 1,\n        batch_norm: bool = False,\n        early_stopping: bool = True,\n        residual: bool = False,\n        loss: Optional[Callable] = None,\n    ) -> None:\n        super(MLP, self).__init__()\n        activation = ACTIVATION_FUNCTIONS[activation] if activation in ACTIVATION_FUNCTIONS else None\n\n        if n_units_in < 0:\n            raise ValueError(\"n_units_in must be >= 0\")\n        if n_units_out < 0:\n            raise ValueError(\"n_units_out must be >= 0\")\n\n        if residual:\n            block = ResidualLayer\n        else:\n            block = LinearLayer\n\n        # network\n        layers = []\n\n        if n_layers_hidden > 0:\n            layers.append(\n                block(\n                    n_units_in,\n                    n_units_hidden,\n                    batch_norm=batch_norm,\n                    activation=activation,\n                )\n            )\n            n_units_hidden += int(residual) * n_units_in\n\n            # add required number of layers\n            for i in range(n_layers_hidden - 1):\n                layers.append(\n                    block(\n                        n_units_hidden,\n                        n_units_hidden,\n                        batch_norm=batch_norm,\n                        activation=activation,\n                        
dropout=dropout,\n                    )\n                )\n                n_units_hidden += int(residual) * n_units_hidden\n\n            # add final layers\n            layers.append(nn.Linear(n_units_hidden, n_units_out))\n        else:\n            layers = [nn.Linear(n_units_in, n_units_out)]\n\n        if activation_out is not None:\n            total_nonlin_len = 0\n            activations = []\n            for nonlin, nonlin_len in activation_out:\n                total_nonlin_len += nonlin_len\n                activations.append((ACTIVATION_FUNCTIONS[nonlin](), nonlin_len))\n\n            if total_nonlin_len != n_units_out:\n                raise RuntimeError(\n                    f\"Shape mismatch for the output layer. Expected length {n_units_out}, but got {activation_out} with length {total_nonlin_len}\"\n                )\n            layers.append(MultiActivationHead(activations))\n\n        self.model = nn.Sequential(*layers)\n\n        # optimizer\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.opt_betas = opt_betas\n        self.optimizer = torch.optim.Adam(\n            self.parameters(),\n            lr=self.lr,\n            weight_decay=self.weight_decay,\n            betas=self.opt_betas,\n        )\n\n        # training\n        self.n_iter = n_iter\n        self.n_iter_print = n_iter_print\n        self.n_iter_min = n_iter_min\n        self.batch_size = batch_size\n        self.patience = patience\n        self.clipping_value = clipping_value\n        self.early_stopping = early_stopping\n        if loss is not None:\n            self.loss = loss\n        else:\n            self.loss = nn.MSELoss()\n\n    def fit(self, X: np.ndarray, y: np.ndarray) -> \"MLP\":\n        Xt = self._check_tensor(X)\n        yt = self._check_tensor(y)\n\n        self._train(Xt, yt)\n\n        return self\n\n    def predict_proba(self, X: np.ndarray) -> np.ndarray:\n        if self.task_type != \"classification\":\n            raise ValueError(f\"Invalid task type for predict_proba {self.task_type}\")\n\n        with torch.no_grad():\n            Xt = self._check_tensor(X)\n\n            yt = self.forward(Xt)\n\n            return yt.cpu().numpy().squeeze()\n\n    def predict(self, X: np.ndarray) -> np.ndarray:\n        with torch.no_grad():\n            Xt = self._check_tensor(X)\n\n            yt = self.forward(Xt)\n\n            if self.task_type == \"classification\":\n                return np.argmax(yt.cpu().numpy().squeeze(), -1).squeeze()\n            else:\n                return yt.cpu().numpy().squeeze()\n\n    def score(self, X: np.ndarray, y: np.ndarray) -> float:\n        y_pred = self.predict(X)\n        if self.task_type == \"classification\":\n            return np.mean(y_pred == y)\n        else:\n            return np.mean(np.inner(y - y_pred, y - y_pred) / 2.0)\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        return self.model(X.float())\n\n    def _train_epoch(self, loader: DataLoader) -> float:\n        train_loss = []\n\n        for batch_ndx, sample in enumerate(loader):\n            self.optimizer.zero_grad()\n\n            X_next, y_next = sample\n            if len(X_next) < 2:\n                continue\n\n            preds = self.forward(X_next).squeeze()\n\n            batch_loss = self.loss(preds, y_next)\n\n            batch_loss.backward()\n\n            if self.clipping_value > 0:\n                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\n\n            self.optimizer.step()\n\n 
           train_loss.append(batch_loss.detach())\n\n        return torch.mean(torch.Tensor(train_loss))\n\n    def _train(self, X: torch.Tensor, y: torch.Tensor) -> \"MLP\":\n        X = self._check_tensor(X).float()\n        y = self._check_tensor(y).squeeze().float()\n        if self.task_type == \"classification\":\n            y = y.long()\n\n        # Load Dataset\n        dataset = TensorDataset(X, y)\n\n        train_size = int(0.8 * len(dataset))\n        test_size = len(dataset) - train_size\n        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])\n        loader = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=False)\n\n        # Setup the network and optimizer\n        val_loss_best = 1e12\n        patience = 0\n\n        # do training\n        for i in range(self.n_iter):\n            self._train_epoch(loader)\n\n            if self.early_stopping or i % self.n_iter_print == 0:\n                with torch.no_grad():\n                    X_val, y_val = test_dataset.dataset.tensors\n\n                    preds = self.forward(X_val).squeeze()\n                    val_loss = self.loss(preds, y_val)\n\n                    if self.early_stopping:\n                        if val_loss_best > val_loss:\n                            val_loss_best = val_loss\n                            patience = 0\n                        else:\n                            patience += 1\n\n                        if patience > self.patience and i > self.n_iter_min:\n                            break\n\n        return self\n\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\n        if isinstance(X, torch.Tensor):\n            return X\n        else:\n            return torch.from_numpy(np.asarray(X))\n\n    def __len__(self) -> int:\n        return len(self.model)\n
"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MultiActivationHead","title":"MultiActivationHead","text":"

Bases: Module

Final layer with multiple activations. Useful for tabular data.

Source code in src/nhssynth/modules/model/common/mlp.py
class MultiActivationHead(nn.Module):\n    \"\"\"Final layer with multiple activations. Useful for tabular data.\"\"\"\n\n    def __init__(\n        self,\n        activations: list[tuple[nn.Module, int]],\n    ) -> None:\n        super(MultiActivationHead, self).__init__()\n        self.activations = []\n        self.activation_lengths = []\n\n        for activation, length in activations:\n            self.activations.append(activation)\n            self.activation_lengths.append(length)\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        if X.shape[-1] != np.sum(self.activation_lengths):\n            raise RuntimeError(\n                f\"Shape mismatch for the activations: expected {np.sum(self.activation_lengths)}. Got shape {X.shape}.\"\n            )\n\n        split = 0\n        out = torch.zeros(X.shape)\n\n        for activation, step in zip(self.activations, self.activation_lengths):\n            out[..., split : split + step] = activation(X[..., split : split + step])\n            split += step\n\n        return out\n
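For example, a head that applies a softmax over a three-unit one-hot block and leaves two continuous outputs untouched could be built as follows (a hedged sketch; the column split is invented):

import torch
import torch.nn as nn

from nhssynth.modules.model.common.mlp import MultiActivationHead

# Softmax over the first 3 outputs, identity over the remaining 2
head = MultiActivationHead([(nn.Softmax(dim=-1), 3), (nn.Identity(), 2)])

x = torch.randn(4, 5)
out = head(x)       # activations applied slice-by-slice
print(out.shape)    # torch.Size([4, 5])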
"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.SkipConnection","title":"SkipConnection(cls)","text":"

Wraps a model to add a skip connection from the input to the output.

Example:

>>> ResidualBlock = SkipConnection(MLP)
>>> res_block = ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64)
>>> res_block(torch.ones(10, 10)).shape
(10, 13)

Source code in src/nhssynth/modules/model/common/mlp.py
def SkipConnection(cls: Type[nn.Module]) -> Type[nn.Module]:\n    \"\"\"Wraps a model to add a skip connection from the input to the output.\n\n    Example:\n    >>> ResidualBlock = SkipConnection(MLP)\n    >>> res_block = ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64)\n    >>> res_block(torch.ones(10, 10)).shape\n    (10, 13)\n    \"\"\"\n\n    class Wrapper(cls):\n        pass\n\n    Wrapper._forward = cls.forward\n    Wrapper.forward = _forward_skip_connection\n    Wrapper.__name__ = f\"SkipConnection({cls.__name__})\"\n    Wrapper.__qualname__ = f\"SkipConnection({cls.__qualname__})\"\n    Wrapper.__doc__ = f\"\"\"(With skipped connection) {cls.__doc__}\"\"\"\n    return Wrapper\n
"},{"location":"reference/modules/model/common/model/","title":"model","text":""},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model","title":"Model","text":"

Bases: Module, ABC

Abstract base class for all NHSSynth models

Parameters:

    data (DataFrame): The data to train on [required]
    metatransformer (MetaTransformer): A MetaTransformer to use for converting the generated data to match the original data [required]
    batch_size (int): The batch size to use during training [default: 32]
    use_gpu (bool): Flag to determine whether to use the GPU (if available) [default: False]

Attributes:

    nrows: The number of rows in the data
    ncols: The number of columns in the data
    columns (Index): The names of the columns in the data
    metatransformer: The MetaTransformer (potentially) associated with the model
    multi_column_indices (list[list[int]]): A list of lists of column indices, where each sublist contains the indices for a one-hot encoded column
    single_column_indices (list[int]): Indices of all non-onehot columns
    data_loader (DataLoader): A PyTorch DataLoader for the data
    private: Whether the model is private, i.e. whether the DPMixin class has been inherited
    device (torch.device): The device to use for training (CPU or GPU)

Raises:

    TypeError: If the Model class is directly instantiated (i.e. not inherited)
    AssertionError: If the number of columns in the data does not match the number of indices in multi_column_indices and single_column_indices
    UserWarning: If use_gpu is True but no GPU is available

Source code in src/nhssynth/modules/model/common/model.py
class Model(nn.Module, ABC):\n    \"\"\"\n    Abstract base class for all NHSSynth models\n\n    Args:\n        data: The data to train on\n        metatransformer: A `MetaTransformer` to use for converting the generated data to match the original data\n        batch_size: The batch size to use during training\n        use_gpu: Flag to determine whether to use the GPU (if available)\n\n    Attributes:\n        nrows: The number of rows in the `data`\n        ncols: The number of columns in the `data`\n        columns: The names of the columns in the `data`\n        metatransformer: The `MetaTransformer` (potentially) associated with the model\n        multi_column_indices: A list of lists of column indices, where each sublist containts the indices for a one-hot encoded column\n        single_column_indices: Indices of all non-onehot columns\n        data_loader: A PyTorch DataLoader for the `data`\n        private: Whether the model is private, i.e. whether the `DPMixin` class has been inherited\n        device: The device to use for training (CPU or GPU)\n\n    Raises:\n        TypeError: If the `Model` class is directly instantiated (i.e. not inherited)\n        AssertionError: If the number of columns in the `data` does not match the number of indices in `multi_column_indices` and `single_column_indices`\n        UserWarning: If `use_gpu` is True but no GPU is available\n    \"\"\"\n\n    def __init__(\n        self,\n        data: pd.DataFrame,\n        metatransformer: MetaTransformer,\n        cond: Optional[Union[pd.DataFrame, pd.Series, np.ndarray]] = None,\n        batch_size: int = 32,\n        use_gpu: bool = False,\n    ) -> None:\n        if type(self) is Model:\n            raise TypeError(\"Cannot directly instantiate the `Model` class\")\n        super().__init__()\n\n        self.nrows, self.ncols = data.shape\n        self.columns: pd.Index = data.columns\n\n        self.batch_size = batch_size\n\n        self.metatransformer = metatransformer\n        self.multi_column_indices: list[list[int]] = metatransformer.multi_column_indices\n        self.single_column_indices: list[int] = metatransformer.single_column_indices\n        assert len(self.single_column_indices) + sum([len(x) for x in self.multi_column_indices]) == self.ncols\n\n        tensor_data = torch.Tensor(data.to_numpy())\n        self.cond_encoder: Optional[OneHotEncoder] = None\n        if cond is not None:\n            cond = np.asarray(cond)\n            if len(cond.shape) == 1:\n                cond = cond.reshape(-1, 1)\n            self.cond_encoder = OneHotEncoder(handle_unknown=\"ignore\").fit(cond)\n            cond = self.cond_encoder.transform(cond).toarray()\n            self.n_units_conditional = cond.shape[-1]\n            dataset = TensorDataset(tensor_data, cond)\n        else:\n            self.n_units_conditional = 0\n            dataset = TensorDataset(tensor_data)\n\n        self.data_loader: DataLoader = DataLoader(\n            dataset,\n            pin_memory=True,\n            batch_size=self.batch_size,\n        )\n        self.setup_device(use_gpu)\n\n    def setup_device(self, use_gpu: bool) -> None:\n        \"\"\"Sets up the device to use for training (CPU or GPU) depending on `use_gpu` and device availability.\"\"\"\n        if use_gpu:\n            if torch.cuda.is_available():\n                self.device: torch.device = torch.device(\"cuda:0\")\n            else:\n                warnings.warn(\"`use_gpu` was provided but no GPU is available, using CPU\")\n        self.device: 
torch.device = torch.device(\"cpu\")\n\n    def save(self, filename: str) -> None:\n        \"\"\"Saves the model to `filename`.\"\"\"\n        torch.save(self.state_dict(), filename)\n\n    def load(self, path: str) -> None:\n        \"\"\"Loads the model from `path`.\"\"\"\n        self.load_state_dict(torch.load(path))\n\n    @classmethod\n    @abstractmethod\n    def get_args() -> list[str]:\n        \"\"\"Returns the list of arguments to look for in an `argparse.Namespace`, these must map to the arguments of the inheritor.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def get_metrics() -> list[str]:\n        \"\"\"Returns the list of metrics to track during training.\"\"\"\n        raise NotImplementedError\n\n    def _start_training(self, num_epochs: int, patience: int, displayed_metrics: list[str]) -> None:\n        \"\"\"\n        Initialises the training process.\n\n        Args:\n            num_epochs: The number of epochs to train for\n            patience: The number of epochs to wait before stopping training early if the loss does not improve\n            displayed_metrics: The metrics to display during training, this should be set to an empty list if running `train` in a notebook or the output may be messy\n\n        Attributes:\n            metrics: A dictionary of lists of tracked metrics, where each list contains the values for each batch\n            stats_bars: A dictionary of tqdm status bars for each tracked metric\n            max_length: The maximum length of the tracked metric names, used for formatting the tqdm status bars\n            start_time: The time at which training started\n            update_time: The time at which the tqdm status bars were last updated\n        \"\"\"\n        self.num_epochs = num_epochs\n        self.patience = patience\n        self.metrics = {metric: np.empty(0, dtype=float) for metric in self.get_metrics()}\n        displayed_metrics = displayed_metrics or self.get_metrics()\n        self.stats_bars = {\n            metric: tqdm(total=0, desc=\"\", position=i, bar_format=\"{desc}\", leave=True)\n            for i, metric in enumerate(displayed_metrics)\n        }\n        self.max_length = max([len(add_spaces_before_caps(s)) + 5 for s in displayed_metrics] + [20])\n        self.start_time = self.update_time = time.time()\n\n    def _generate_metric_str(self, key) -> str:\n        \"\"\"Generates a string to display the current value of the metric `key`.\"\"\"\n        return f\"{(add_spaces_before_caps(key) + ':').ljust(self.max_length)}  {np.mean(self.metrics[key][-len(self.data_loader) :]):.4f}\"\n\n    def _record_metrics(self, losses):\n        \"\"\"Records the metrics for the current batch to file and updates the tqdm status bars.\"\"\"\n        for key in self.metrics.keys():\n            if key in losses:\n                if losses[key]:\n                    self.metrics[key] = np.append(\n                        self.metrics[key], losses[key].item() if isinstance(losses[key], torch.Tensor) else losses[key]\n                    )\n        if time.time() - self.update_time > 0.5:\n            for key, stats_bar in self.stats_bars.items():\n                stats_bar.set_description_str(self._generate_metric_str(key))\n                self.update_time = time.time()\n\n    def _check_patience(self, epoch: int, metric: float) -> bool:\n        \"\"\"Maintains `_min_metric` and `_stop_counter` to determine whether to stop training early according to `patience`.\"\"\"\n        if epoch == 
0:\n            self._stop_counter = 0\n            self._min_metric = metric\n            self._patience_delta = self._min_metric / 1e4\n        if metric < (self._min_metric - self._patience_delta):\n            self._min_metric = metric\n            self._stop_counter = 0  # Set counter to zero\n        else:  # elbo has not improved\n            self._stop_counter += 1\n        return self._stop_counter == self.patience\n\n    def _finish_training(self, num_epochs: int) -> None:\n        \"\"\"Closes each of the tqdm status bars and prints the time taken to do `num_epochs`.\"\"\"\n        for stats_bar in self.stats_bars.values():\n            stats_bar.close()\n        tqdm.write(f\"Completed {num_epochs} epochs in {time.time() - self.start_time:.2f} seconds.\\033[0m\")\n
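The two abstract classmethods define the whole contract a new architecture has to meet before it can be driven by the module's argument handling; a hypothetical (non-package) subclass would look roughly like this:

from nhssynth.modules.model.common.model import Model

class MyModel(Model):  # hypothetical architecture, not part of NHSSynth
    @classmethod
    def get_args(cls) -> list[str]:
        # names looked up on the argparse.Namespace and passed to __init__
        return ["my_hidden_dim", "my_learning_rate"]

    @classmethod
    def get_metrics(cls) -> list[str]:
        # metrics tracked (and displayed) during training
        return ["Loss"]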
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.get_args","title":"get_args() abstractmethod classmethod","text":"

Returns the list of arguments to look for in an argparse.Namespace, these must map to the arguments of the inheritor.

Source code in src/nhssynth/modules/model/common/model.py
@classmethod\n@abstractmethod\ndef get_args() -> list[str]:\n    \"\"\"Returns the list of arguments to look for in an `argparse.Namespace`, these must map to the arguments of the inheritor.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.get_metrics","title":"get_metrics() abstractmethod classmethod","text":"

Returns the list of metrics to track during training.

Source code in src/nhssynth/modules/model/common/model.py
@classmethod\n@abstractmethod\ndef get_metrics() -> list[str]:\n    \"\"\"Returns the list of metrics to track during training.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.load","title":"load(path)","text":"

Loads the model from path.

Source code in src/nhssynth/modules/model/common/model.py
def load(self, path: str) -> None:\n    \"\"\"Loads the model from `path`.\"\"\"\n    self.load_state_dict(torch.load(path))\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.save","title":"save(filename)","text":"

Saves the model to filename.

Source code in src/nhssynth/modules/model/common/model.py
def save(self, filename: str) -> None:\n    \"\"\"Saves the model to `filename`.\"\"\"\n    torch.save(self.state_dict(), filename)\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.setup_device","title":"setup_device(use_gpu)","text":"

Sets up the device to use for training (CPU or GPU) depending on use_gpu and device availability.

Source code in src/nhssynth/modules/model/common/model.py
def setup_device(self, use_gpu: bool) -> None:\n    \"\"\"Sets up the device to use for training (CPU or GPU) depending on `use_gpu` and device availability.\"\"\"\n    if use_gpu:\n        if torch.cuda.is_available():\n            self.device: torch.device = torch.device(\"cuda:0\")\n        else:\n            warnings.warn(\"`use_gpu` was provided but no GPU is available, using CPU\")\n    self.device: torch.device = torch.device(\"cpu\")\n
"},{"location":"reference/modules/model/models/","title":"models","text":""},{"location":"reference/modules/model/models/dpvae/","title":"dpvae","text":""},{"location":"reference/modules/model/models/dpvae/#nhssynth.modules.model.models.dpvae.DPVAE","title":"DPVAE","text":"

Bases: DPMixin, VAE

A differentially private VAE. Accepts VAE arguments as well as DPMixin arguments.

Source code in src/nhssynth/modules/model/models/dpvae.py
class DPVAE(DPMixin, VAE):\n    \"\"\"\n    A differentially private VAE. Accepts [`VAE`][nhssynth.modules.model.models.vae.VAE] arguments\n    as well as [`DPMixin`][nhssynth.modules.model.common.dp.DPMixin] arguments.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        target_epsilon: float = 3.0,\n        target_delta: Optional[float] = None,\n        max_grad_norm: float = 5.0,\n        secure_mode: bool = False,\n        shared_optimizer: bool = False,\n        **kwargs,\n    ) -> None:\n        super(DPVAE, self).__init__(\n            *args,\n            target_epsilon=target_epsilon,\n            target_delta=target_delta,\n            max_grad_norm=max_grad_norm,\n            secure_mode=secure_mode,\n            # TODO fix shared_optimizer workflow for DP models\n            shared_optimizer=False,\n            **kwargs,\n        )\n\n    def make_private(self, num_epochs: int) -> GradSampleModule:\n        \"\"\"\n        Make the [`Decoder`][nhssynth.modules.model.models.vae.Decoder] differentially private\n        unless `shared_optimizer` is True, in which case the whole VAE will be privatised.\n\n        Args:\n            num_epochs: The number of epochs to train for\n        \"\"\"\n        if self.shared_optimizer:\n            super().make_private(num_epochs)\n        else:\n            self.decoder = super().make_private(num_epochs, self.decoder)\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return VAE.get_args() + DPMixin.get_args()\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return VAE.get_metrics() + DPMixin.get_metrics()\n
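Since the class simply concatenates its parents' argument and metric lists, the set of arguments it responds to is the union of the VAE and DPMixin ones, e.g.:

from nhssynth.modules.model.models.dpvae import DPVAE

# VAE.get_args() + DPMixin.get_args(), in that order
print(DPVAE.get_args())
# ['encoder_latent_dim', ..., 'shared_optimizer',
#  'target_epsilon', 'target_delta', 'max_grad_norm', 'secure_mode']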
"},{"location":"reference/modules/model/models/dpvae/#nhssynth.modules.model.models.dpvae.DPVAE.make_private","title":"make_private(num_epochs)","text":"

Make the Decoder differentially private unless shared_optimizer is True, in which case the whole VAE will be privatised.

Parameters:

    num_epochs (int): The number of epochs to train for [required]

Source code in src/nhssynth/modules/model/models/dpvae.py
def make_private(self, num_epochs: int) -> GradSampleModule:\n    \"\"\"\n    Make the [`Decoder`][nhssynth.modules.model.models.vae.Decoder] differentially private\n    unless `shared_optimizer` is True, in which case the whole VAE will be privatised.\n\n    Args:\n        num_epochs: The number of epochs to train for\n    \"\"\"\n    if self.shared_optimizer:\n        super().make_private(num_epochs)\n    else:\n        self.decoder = super().make_private(num_epochs, self.decoder)\n
"},{"location":"reference/modules/model/models/gan/","title":"gan","text":""},{"location":"reference/modules/model/models/gan/#nhssynth.modules.model.models.gan.GAN","title":"GAN","text":"

Bases: Model

Basic GAN implementation.

Parameters:

    n_units_conditional (int): Number of conditional units [default: 0]
    generator_n_layers_hidden (int): Number of hidden layers in the generator [default: 2]
    generator_n_units_hidden (int): Number of hidden units in each layer of the Generator [default: 250]
    generator_activation (str): Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. [default: 'leaky_relu']
    generator_n_iter (int): Maximum number of iterations in the Generator. [required]
    generator_batch_norm (bool): Enable/disable batch norm for the generator [default: False]
    generator_dropout (float): Dropout value. If 0, the dropout is not used. [default: 0]
    generator_residual (bool): Use residuals for the generator [default: True]
    generator_activation_out (Optional[List[Tuple[str, int]]]): List of activations. Useful with the TabularEncoder [required]
    generator_lr (float): Generator learning rate, used by the Adam optimizer [default: 2e-4]
    generator_weight_decay (float): Generator weight decay, used by the Adam optimizer [required]
    generator_opt_betas (tuple): Generator initial decay rates, used by the Adam optimizer [default: (0.9, 0.999)]
    generator_extra_penalty_cbks (List[Callable]): Additional loss callbacks for the generator. Used by the TabularGAN for the conditional loss [required]
    discriminator_n_layers_hidden (int): Number of hidden layers in the discriminator [default: 3]
    discriminator_n_units_hidden (int): Number of hidden units in each layer of the discriminator [default: 300]
    discriminator_activation (str): Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. [default: 'leaky_relu']
    discriminator_batch_norm (bool): Enable/disable batch norm for the discriminator [default: False]
    discriminator_dropout (float): Dropout value for the discriminator. If 0, the dropout is not used. [default: 0.1]
    discriminator_lr (float): Discriminator learning rate, used by the Adam optimizer [default: 2e-4]
    discriminator_weight_decay (float): Discriminator weight decay, used by the Adam optimizer [required]
    discriminator_opt_betas (tuple): Initial weight decays for the Adam optimizer [default: (0.9, 0.999)]
    clipping_value (int): Gradients clipping value. Zero disables the feature [default: 0]
    lambda_gradient_penalty (float): Weight for the gradient penalty [default: 10]

Source code in src/nhssynth/modules/model/models/gan.py
class GAN(Model):\n    \"\"\"\n    Basic GAN implementation.\n\n    Args:\n        n_units_conditional: int\n            Number of conditional units\n        generator_n_layers_hidden: int\n            Number of hidden layers in the generator\n        generator_n_units_hidden: int\n            Number of hidden units in each layer of the Generator\n        generator_activation: string, default 'elu'\n            Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n        generator_n_iter: int\n            Maximum number of iterations in the Generator.\n        generator_batch_norm: bool\n            Enable/disable batch norm for the generator\n        generator_dropout: float\n            Dropout value. If 0, the dropout is not used.\n        generator_residual: bool\n            Use residuals for the generator\n        generator_activation_out: Optional[List[Tuple[str, int]]]\n            List of activations. Useful with the TabularEncoder\n        generator_lr: float = 2e-4\n            Generator learning rate, used by the Adam optimizer\n        generator_weight_decay: float = 1e-3\n            Generator weight decay, used by the Adam optimizer\n        generator_opt_betas: tuple = (0.9, 0.999)\n            Generator initial decay rates, used by the Adam Optimizer\n        generator_extra_penalty_cbks: List[Callable]\n            Additional loss callabacks for the generator. Used by the TabularGAN for the conditional loss\n        discriminator_n_layers_hidden: int\n            Number of hidden layers in the discriminator\n        discriminator_n_units_hidden: int\n            Number of hidden units in each layer of the discriminator\n        discriminator_activation: string, default 'relu'\n            Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n        discriminator_batch_norm: bool\n            Enable/disable batch norm for the discriminator\n        discriminator_dropout: float\n            Dropout value for the discriminator. If 0, the dropout is not used.\n        discriminator_lr: float\n            Discriminator learning rate, used by the Adam optimizer\n        discriminator_weight_decay: float\n            Discriminator weight decay, used by the Adam optimizer\n        discriminator_opt_betas: tuple\n            Initial weight decays for the Adam optimizer\n        clipping_value: int, default 0\n            Gradients clipping value. 
Zero disables the feature\n        lambda_gradient_penalty: float = 10\n            Weight for the gradient penalty\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        n_units_conditional: int = 0,\n        generator_n_layers_hidden: int = 2,\n        generator_n_units_hidden: int = 250,\n        generator_activation: str = \"leaky_relu\",\n        generator_batch_norm: bool = False,\n        generator_dropout: float = 0,\n        generator_lr: float = 2e-4,\n        generator_residual: bool = True,\n        generator_opt_betas: tuple = (0.9, 0.999),\n        discriminator_n_layers_hidden: int = 3,\n        discriminator_n_units_hidden: int = 300,\n        discriminator_activation: str = \"leaky_relu\",\n        discriminator_batch_norm: bool = False,\n        discriminator_dropout: float = 0.1,\n        discriminator_lr: float = 2e-4,\n        discriminator_opt_betas: tuple = (0.9, 0.999),\n        clipping_value: int = 0,\n        lambda_gradient_penalty: float = 10,\n        **kwargs,\n    ) -> None:\n        super(GAN, self).__init__(*args, **kwargs)\n\n        self.generator_n_units_hidden = generator_n_units_hidden\n        self.n_units_conditional = n_units_conditional\n\n        self.generator = MLP(\n            n_units_in=generator_n_units_hidden + n_units_conditional,\n            n_units_out=self.ncols,\n            n_layers_hidden=generator_n_layers_hidden,\n            n_units_hidden=generator_n_units_hidden,\n            activation=generator_activation,\n            # nonlin_out=generator_activation_out,\n            batch_norm=generator_batch_norm,\n            dropout=generator_dropout,\n            lr=generator_lr,\n            residual=generator_residual,\n            opt_betas=generator_opt_betas,\n        ).to(self.device)\n\n        self.discriminator = MLP(\n            n_units_in=self.ncols + n_units_conditional,\n            n_units_out=1,\n            n_layers_hidden=discriminator_n_layers_hidden,\n            n_units_hidden=discriminator_n_units_hidden,\n            activation=discriminator_activation,\n            activation_out=[(\"none\", 1)],\n            batch_norm=discriminator_batch_norm,\n            dropout=discriminator_dropout,\n            lr=discriminator_lr,\n            opt_betas=discriminator_opt_betas,\n        ).to(self.device)\n\n        self.clipping_value = clipping_value\n        self.lambda_gradient_penalty = lambda_gradient_penalty\n\n        def gen_fake_labels(X: torch.Tensor) -> torch.Tensor:\n            return torch.zeros((len(X),), device=self.device)\n\n        def gen_true_labels(X: torch.Tensor) -> torch.Tensor:\n            return torch.ones((len(X),), device=self.device)\n\n        self.fake_labels_generator = gen_fake_labels\n        self.true_labels_generator = gen_true_labels\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\n            \"n_units_conditional\",\n            \"generator_n_layers_hidden\",\n            \"generator_n_units_hidden\",\n            \"generator_activation\",\n            \"generator_batch_norm\",\n            \"generator_dropout\",\n            \"generator_lr\",\n            \"generator_residual\",\n            \"generator_opt_betas\",\n            \"discriminator_n_layers_hidden\",\n            \"discriminator_n_units_hidden\",\n            \"discriminator_activation\",\n            \"discriminator_batch_norm\",\n            \"discriminator_dropout\",\n            \"discriminator_lr\",\n            \"discriminator_opt_betas\",\n            
\"clipping_value\",\n            \"lambda_gradient_penalty\",\n        ]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\"GLoss\", \"DLoss\"]\n\n    def generate(self, N: int, cond: Optional[np.ndarray] = None) -> np.ndarray:\n        N = N or self.nrows\n        self.generator.eval()\n\n        condt: Optional[torch.Tensor] = None\n        if cond is not None:\n            condt = self._check_tensor(cond)\n        with torch.no_grad():\n            return self.metatransformer.inverse_apply(\n                pd.DataFrame(self(N, condt).detach().cpu().numpy(), columns=self.columns)\n            )\n\n    def forward(\n        self,\n        N: int,\n        cond: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if cond is None and self.n_units_conditional > 0:\n            # sample from the original conditional\n            if self._original_cond is None:\n                raise ValueError(\"Invalid original conditional. Provide a valid value.\")\n            cond_idxs = torch.randint(len(self._original_cond), (N,))\n            cond = self._original_cond[cond_idxs]\n\n        if cond is not None and len(cond.shape) == 1:\n            cond = cond.reshape(-1, 1)\n\n        if cond is not None and len(cond) != N:\n            raise ValueError(\"cond length must match N\")\n\n        fixed_noise = torch.randn(N, self.generator_n_units_hidden, device=self.device)\n        fixed_noise = self._append_optional_cond(fixed_noise, cond)\n\n        return self.generator(fixed_noise)\n\n    def _train_epoch_generator(\n        self,\n        X: torch.Tensor,\n        cond: Optional[torch.Tensor],\n    ) -> float:\n        # Update the G network\n        self.generator.train()\n        self.generator.optimizer.zero_grad()\n\n        real_X_raw = X.to(self.device)\n        real_X = self._append_optional_cond(real_X_raw, cond)\n        batch_size = len(real_X)\n\n        noise = torch.randn(batch_size, self.generator_n_units_hidden, device=self.device)\n        noise = self._append_optional_cond(noise, cond)\n\n        fake_raw = self.generator(noise)\n        fake = self._append_optional_cond(fake_raw, cond)\n\n        output = self.discriminator(fake).squeeze().float()\n        # Calculate G's loss based on this output\n        errG = -torch.mean(output)\n        if hasattr(self, \"generator_extra_penalty_cbks\"):\n            for extra_loss in self.generator_extra_penalty_cbks:\n                errG += extra_loss(\n                    real_X_raw,\n                    fake_raw,\n                    cond=cond,\n                )\n\n        # Calculate gradients for G\n        errG.backward()\n\n        # Update G\n        if self.clipping_value > 0:\n            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.clipping_value)\n        self.generator.optimizer.step()\n\n        if torch.isnan(errG):\n            raise RuntimeError(\"NaNs detected in the generator loss\")\n\n        # Return loss\n        return errG.item()\n\n    def _train_epoch_discriminator(\n        self,\n        X: torch.Tensor,\n        cond: Optional[torch.Tensor],\n    ) -> float:\n        # Update the D network\n        self.discriminator.train()\n\n        errors = []\n\n        batch_size = min(self.batch_size, len(X))\n\n        # Train with all-real batch\n        real_X = X.to(self.device)\n        real_X = self._append_optional_cond(real_X, cond)\n\n        real_labels = self.true_labels_generator(X).to(self.device).squeeze()\n        real_output = 
self.discriminator(real_X).squeeze().float()\n\n        # Train with all-fake batch\n        noise = torch.randn(batch_size, self.generator_n_units_hidden, device=self.device)\n        noise = self._append_optional_cond(noise, cond)\n\n        fake_raw = self.generator(noise)\n        fake = self._append_optional_cond(fake_raw, cond)\n\n        fake_labels = self.fake_labels_generator(fake_raw).to(self.device).squeeze().float()\n        fake_output = self.discriminator(fake.detach()).squeeze()\n\n        # Compute errors. Some fake inputs might be marked as real for privacy guarantees.\n\n        real_real_output = real_output[(real_labels * real_output) != 0]\n        real_fake_output = fake_output[(fake_labels * fake_output) != 0]\n        errD_real = torch.mean(torch.concat((real_real_output, real_fake_output)))\n\n        fake_real_output = real_output[((1 - real_labels) * real_output) != 0]\n        fake_fake_output = fake_output[((1 - fake_labels) * fake_output) != 0]\n        errD_fake = torch.mean(torch.concat((fake_real_output, fake_fake_output)))\n\n        penalty = self._loss_gradient_penalty(\n            real_samples=real_X,\n            fake_samples=fake,\n            batch_size=batch_size,\n        )\n        errD = -errD_real + errD_fake\n\n        self.discriminator.optimizer.zero_grad()\n        if isinstance(self, DPMixin):\n            # Adversarial loss\n            # 1. split fwd-bkwd on fake and real images into two explicit blocks.\n            # 2. no need to compute per_sample_gardients on fake data, disable hooks.\n            # 3. re-enable hooks to obtain per_sample_gardients for real data.\n            # fake fwd-bkwd\n            self.discriminator.disable_hooks()\n            penalty.backward(retain_graph=True)\n            errD_fake.backward(retain_graph=True)\n\n            self.discriminator.enable_hooks()\n            errD_real.backward()  # HACK: calling bkwd without zero_grad() accumulates param gradients\n        else:\n            penalty.backward(retain_graph=True)\n            errD.backward()\n\n        # Update D\n        if self.clipping_value > 0:\n            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.clipping_value)\n        self.discriminator.optimizer.step()\n\n        errors.append(errD.item())\n\n        if np.isnan(np.mean(errors)):\n            raise RuntimeError(\"NaNs detected in the discriminator loss\")\n\n        return np.mean(errors)\n\n    def _train_epoch(self) -> Tuple[float, float]:\n        for data in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n            cond: Optional[torch.Tensor] = None\n            if self.n_units_conditional > 0:\n                X, cond = data\n            else:\n                X = data[0]\n\n            losses = {\n                \"DLoss\": self._train_epoch_discriminator(X, cond),\n                \"GLoss\": self._train_epoch_generator(X, cond),\n            }\n            self._record_metrics(losses)\n\n        return np.mean(self.metrics[\"GLoss\"][-len(self.data_loader) :]), np.mean(\n            self.metrics[\"DLoss\"][-len(self.data_loader) :]\n        )\n\n    def train(\n        self,\n        num_epochs: int = 100,\n        patience: int = 5,\n        displayed_metrics: list[str] = [\"GLoss\", \"DLoss\"],\n    ) -> tuple[int, dict[str, np.ndarray]]:\n        self._start_training(num_epochs, patience, displayed_metrics)\n\n        for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), 
leave=False):\n            losses = self._train_epoch()\n            if self._check_patience(epoch, losses[0]) and self._check_patience(epoch, losses[1]):\n                num_epochs = epoch + 1\n                break\n\n        self._finish_training(num_epochs)\n        return (num_epochs, self.metrics)\n\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\n        if isinstance(X, torch.Tensor):\n            return X.to(self.device)\n        else:\n            return torch.from_numpy(np.asarray(X)).to(self.device)\n\n    def _loss_gradient_penalty(\n        self,\n        real_samples: torch.tensor,\n        fake_samples: torch.Tensor,\n        batch_size: int,\n    ) -> torch.Tensor:\n        \"\"\"Calculates the gradient penalty loss for WGAN GP\"\"\"\n        # Random weight term for interpolation between real and fake samples\n        alpha = torch.rand([batch_size, 1]).to(self.device)\n        # Get random interpolation between real and fake samples\n        interpolated = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)\n        d_interpolated = self.discriminator(interpolated).squeeze()\n        labels = torch.ones((len(interpolated),), device=self.device)\n\n        # Get gradient w.r.t. interpolates\n        gradients = torch.autograd.grad(\n            outputs=d_interpolated,\n            inputs=interpolated,\n            grad_outputs=labels,\n            create_graph=True,\n            retain_graph=True,\n            only_inputs=True,\n            allow_unused=True,\n        )[0]\n        gradients = gradients.view(gradients.size(0), -1)\n        gradient_penalty = ((gradients.norm(2, dim=-1) - 1) ** 2).mean()\n        return self.lambda_gradient_penalty * gradient_penalty\n\n    def _append_optional_cond(self, X: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:\n        if cond is None:\n            return X\n\n        return torch.cat([X, cond], dim=1)\n
"},{"location":"reference/modules/model/models/vae/","title":"vae","text":""},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.Decoder","title":"Decoder","text":"

Bases: Module

Decoder, takes in z and outputs reconstruction

Source code in src/nhssynth/modules/model/models/vae.py
class Decoder(nn.Module):\n    \"\"\"Decoder, takes in z and outputs reconstruction\"\"\"\n\n    def __init__(\n        self,\n        output_dim: int,\n        latent_dim: int,\n        hidden_dim: int,\n        activation: str,\n        learning_rate: float,\n        shared_optimizer: bool,\n    ) -> None:\n        super().__init__()\n        activation = ACTIVATION_FUNCTIONS[activation]\n        self.net = nn.Sequential(\n            nn.Linear(latent_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, output_dim),\n        )\n        if not shared_optimizer:\n            self.optim = torch.optim.Adam(self.parameters(), lr=learning_rate)\n\n    def forward(self, z):\n        return self.net(z)\n
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.Encoder","title":"Encoder","text":"

Bases: Module

Encoder, takes in x and outputs mu_z, sigma_z (diagonal Gaussian variational posterior assumed)

Source code in src/nhssynth/modules/model/models/vae.py
class Encoder(nn.Module):\n    \"\"\"Encoder, takes in x and outputs mu_z, sigma_z (diagonal Gaussian variational posterior assumed)\"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        latent_dim: int,\n        hidden_dim: int,\n        activation: str,\n        learning_rate: float,\n        shared_optimizer: bool,\n    ) -> None:\n        super().__init__()\n        activation = ACTIVATION_FUNCTIONS[activation]\n        self.latent_dim = latent_dim\n        self.net = nn.Sequential(\n            nn.Linear(input_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, 2 * latent_dim),\n        )\n        if not shared_optimizer:\n            self.optim = torch.optim.Adam(self.parameters(), lr=learning_rate)\n\n    def forward(self, x):\n        outs = self.net(x)\n        mu_z = outs[:, : self.latent_dim]\n        logsigma_z = outs[:, self.latent_dim :]\n        return mu_z, logsigma_z\n
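The encoder's two outputs parameterise a diagonal Gaussian over the latent space; a hedged sketch of the standard reparameterisation step they are typically combined with (illustrative only, not the package's own training loop):

import torch

# Stand-ins for a batch of Encoder outputs (batch size 4, latent_dim 8)
mu_z = torch.zeros(4, 8)
logsigma_z = torch.zeros(4, 8)

# z = mu + sigma * eps keeps the sample differentiable w.r.t. mu_z and logsigma_z
eps = torch.randn_like(mu_z)
z = mu_z + torch.exp(logsigma_z) * eps

print(z.shape)  # torch.Size([4, 8]) - ready to be passed to the Decoder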
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.VAE","title":"VAE","text":"

Bases: Model

A Variational Autoencoder (VAE) model. Accepts Model arguments as well as the following:

Parameters:

    encoder_latent_dim (int): The dimensionality of the latent space. [default: 256]
    encoder_hidden_dim (int): The dimensionality of the hidden layers in the encoder. [default: 256]
    encoder_activation (str): The activation function to use in the encoder. [default: 'leaky_relu']
    encoder_learning_rate (float): The learning rate for the encoder. [default: 0.001]
    decoder_latent_dim (int): The dimensionality of the latent space in the decoder. [default: 256]
    decoder_hidden_dim (int): The dimensionality of the hidden layers in the decoder. [default: 32]
    decoder_activation (str): The activation function to use in the decoder. [default: 'leaky_relu']
    decoder_learning_rate (float): The learning rate for the decoder. [default: 0.001]
    shared_optimizer (bool): Whether to use a shared optimizer for the encoder and decoder. [default: True]

Source code in src/nhssynth/modules/model/models/vae.py
class VAE(Model):\n    \"\"\"\n    A Variational Autoencoder (VAE) model. Accepts [`Model`][nhssynth.modules.model.common.model.Model] arguments as well as the following:\n\n    Args:\n        encoder_latent_dim: The dimensionality of the latent space.\n        encoder_hidden_dim: The dimensionality of the hidden layers in the encoder.\n        encoder_activation: The activation function to use in the encoder.\n        encoder_learning_rate: The learning rate for the encoder.\n        decoder_latent_dim: The dimensionality of the hidden layers in the decoder.\n        decoder_hidden_dim: The dimensionality of the hidden layers in the decoder.\n        decoder_activation: The activation function to use in the decoder.\n        decoder_learning_rate: The learning rate for the decoder.\n        shared_optimizer: Whether to use a shared optimizer for the encoder and decoder.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        encoder_latent_dim: int = 256,\n        encoder_hidden_dim: int = 256,\n        encoder_activation: str = \"leaky_relu\",\n        encoder_learning_rate: float = 1e-3,\n        decoder_latent_dim: int = 256,\n        decoder_hidden_dim: int = 32,\n        decoder_activation: str = \"leaky_relu\",\n        decoder_learning_rate: float = 1e-3,\n        shared_optimizer: bool = True,\n        **kwargs,\n    ) -> None:\n        super(VAE, self).__init__(*args, **kwargs)\n\n        self.shared_optimizer = shared_optimizer\n        self.encoder = Encoder(\n            input_dim=self.ncols,\n            latent_dim=encoder_latent_dim,\n            hidden_dim=encoder_hidden_dim,\n            activation=encoder_activation,\n            learning_rate=encoder_learning_rate,\n            shared_optimizer=self.shared_optimizer,\n        ).to(self.device)\n        self.decoder = Decoder(\n            output_dim=self.ncols,\n            latent_dim=decoder_latent_dim,\n            hidden_dim=decoder_hidden_dim,\n            activation=decoder_activation,\n            learning_rate=decoder_learning_rate,\n            shared_optimizer=self.shared_optimizer,\n        ).to(self.device)\n        self.noiser = Noiser(\n            len(self.single_column_indices),\n        ).to(self.device)\n        if self.shared_optimizer:\n            assert (\n                encoder_learning_rate == decoder_learning_rate\n            ), \"If `shared_optimizer` is True, `encoder_learning_rate` must equal `decoder_learning_rate`\"\n            self.optim = torch.optim.Adam(\n                list(self.encoder.parameters()) + list(self.decoder.parameters()),\n                lr=encoder_learning_rate,\n            )\n            self.zero_grad = self.optim.zero_grad\n            self.step = self.optim.step\n        else:\n            self.zero_grad = lambda: (self.encoder.optim.zero_grad(), self.decoder.optim.zero_grad())\n            self.step = lambda: (self.encoder.optim.step(), self.decoder.optim.step())\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\n            \"encoder_latent_dim\",\n            \"encoder_hidden_dim\",\n            \"encoder_activation\",\n            \"encoder_learning_rate\",\n            \"decoder_latent_dim\",\n            \"decoder_hidden_dim\",\n            \"decoder_activation\",\n            \"decoder_learning_rate\",\n            \"shared_optimizer\",\n        ]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\n            \"ELBO\",\n            \"KLD\",\n            \"ReconstructionLoss\",\n           
 \"CategoricalLoss\",\n            \"NumericalLoss\",\n        ]\n\n    def reconstruct(self, X):\n        mu_z, logsigma_z = self.encoder(X)\n        x_recon = self.decoder(mu_z)\n        return x_recon\n\n    def generate(self, N: Optional[int] = None) -> pd.DataFrame:\n        N = N or self.nrows\n        z_samples = torch.randn_like(torch.ones((N, self.encoder.latent_dim)), device=self.device)\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n            x_gen = self.decoder(z_samples)\n        x_gen_ = torch.ones_like(x_gen, device=self.device)\n\n        if self.multi_column_indices != [[]]:\n            for cat_idxs in self.multi_column_indices:\n                x_gen_[:, cat_idxs] = torch.distributions.one_hot_categorical.OneHotCategorical(\n                    logits=x_gen[:, cat_idxs]\n                ).sample()\n\n        x_gen_[:, self.single_column_indices] = x_gen[:, self.single_column_indices] + torch.exp(\n            self.noiser(x_gen[:, self.single_column_indices])\n        ) * torch.randn_like(x_gen[:, self.single_column_indices])\n        if torch.cuda.is_available():\n            x_gen_ = x_gen_.cpu()\n        return self.metatransformer.inverse_apply(pd.DataFrame(x_gen_.detach(), columns=self.columns))\n\n    def loss(self, X):\n        mu_z, logsigma_z = self.encoder(X)\n\n        p = Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))\n        q = Normal(mu_z, torch.exp(logsigma_z))\n\n        kld = torch.sum(torch.distributions.kl_divergence(q, p))\n\n        s = torch.randn_like(mu_z)\n        z_samples = mu_z + s * torch.exp(logsigma_z)\n\n        x_recon = self.decoder(z_samples)\n\n        categoric_loglik = 0\n\n        if self.multi_column_indices != [[]]:\n            for cat_idxs in self.multi_column_indices:\n                categoric_loglik += -torch.nn.functional.cross_entropy(\n                    x_recon[:, cat_idxs],\n                    torch.max(X[:, cat_idxs], 1)[1],\n                ).sum()\n\n        gauss_loglik = 0\n        if self.single_column_indices:\n            gauss_loglik = (\n                Normal(\n                    loc=x_recon[:, self.single_column_indices],\n                    scale=torch.exp(self.noiser(x_recon[:, self.single_column_indices])),\n                )\n                .log_prob(X[:, self.single_column_indices])\n                .sum()\n            )\n\n        reconstruction_loss = -(categoric_loglik + gauss_loglik)\n\n        elbo = kld + reconstruction_loss\n\n        return {\n            \"ELBO\": elbo / X.size()[0],\n            \"ReconstructionLoss\": reconstruction_loss / X.size()[0],\n            \"KLD\": kld / X.size()[0],\n            \"CategoricalLoss\": categoric_loglik / X.size()[0],\n            \"NumericalLoss\": gauss_loglik / X.size()[0],\n        }\n\n    def train(\n        self,\n        num_epochs: int = 100,\n        patience: int = 5,\n        displayed_metrics: list[str] = [\"ELBO\"],\n    ) -> tuple[int, dict[str, list[float]]]:\n        \"\"\"\n        Train the model.\n\n        Args:\n            num_epochs: Number of epochs to train for.\n            patience: Number of epochs to wait for improvement before early stopping.\n            displayed_metrics: List of metrics to display during training.\n\n        Returns:\n            The number of epochs trained for and a dictionary of the tracked metrics.\n        \"\"\"\n        self._start_training(num_epochs, patience, displayed_metrics)\n\n        
self.encoder.train()\n        self.decoder.train()\n        self.noiser.train()\n\n        for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), leave=False):\n            for (Y_subset,) in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n                self.zero_grad()\n                with warnings.catch_warnings():\n                    warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n                    losses = self.loss(Y_subset.to(self.device))\n                losses[\"ELBO\"].backward()\n                self.step()\n                self._record_metrics(losses)\n\n            elbo = np.mean(self.metrics[\"ELBO\"][-len(self.data_loader) :])\n            if self._check_patience(epoch, elbo):\n                num_epochs = epoch + 1\n                break\n\n        self._finish_training(num_epochs)\n        return (num_epochs, self.metrics)\n
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.VAE.train","title":"train(num_epochs=100, patience=5, displayed_metrics=['ELBO'])","text":"

Train the model.

Parameters:

  • num_epochs (int): Number of epochs to train for. Default: 100
  • patience (int): Number of epochs to wait for improvement before early stopping. Default: 5
  • displayed_metrics (list[str]): List of metrics to display during training. Default: ['ELBO']

Returns:

  • tuple[int, dict[str, list[float]]]: The number of epochs trained for and a dictionary of the tracked metrics.

Source code in src/nhssynth/modules/model/models/vae.py
def train(\n    self,\n    num_epochs: int = 100,\n    patience: int = 5,\n    displayed_metrics: list[str] = [\"ELBO\"],\n) -> tuple[int, dict[str, list[float]]]:\n    \"\"\"\n    Train the model.\n\n    Args:\n        num_epochs: Number of epochs to train for.\n        patience: Number of epochs to wait for improvement before early stopping.\n        displayed_metrics: List of metrics to display during training.\n\n    Returns:\n        The number of epochs trained for and a dictionary of the tracked metrics.\n    \"\"\"\n    self._start_training(num_epochs, patience, displayed_metrics)\n\n    self.encoder.train()\n    self.decoder.train()\n    self.noiser.train()\n\n    for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), leave=False):\n        for (Y_subset,) in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n            self.zero_grad()\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n                losses = self.loss(Y_subset.to(self.device))\n            losses[\"ELBO\"].backward()\n            self.step()\n            self._record_metrics(losses)\n\n        elbo = np.mean(self.metrics[\"ELBO\"][-len(self.data_loader) :])\n        if self._check_patience(epoch, elbo):\n            num_epochs = epoch + 1\n            break\n\n    self._finish_training(num_epochs)\n    return (num_epochs, self.metrics)\n
"},{"location":"reference/modules/plotting/","title":"plotting","text":""},{"location":"reference/modules/plotting/io/","title":"io","text":""},{"location":"reference/modules/plotting/io/#nhssynth.modules.plotting.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_typed, fn_evaluations, dir_experiment)","text":"

Sets up the input and output paths for the model files.

Parameters:

  • fn_dataset (str): The base name of the dataset. Required.
  • fn_typed (str): The name of the typed data file. Required.
  • fn_evaluations (str): The name of the file containing the evaluation bundle. Required.
  • dir_experiment (Path): The path to the experiment directory. Required.

Returns:

  • tuple[str, str]: The paths to the data, metadata and metatransformer files.

Source code in src/nhssynth/modules/plotting/io.py
def check_input_paths(fn_dataset: str, fn_typed: str, fn_evaluations: str, dir_experiment: Path) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_typed: The name of the typed data file.\n        fn_evaluations: The name of the file containing the evaluation bundle.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset, fn_typed, fn_evaluations = io.consistent_endings([fn_dataset, fn_typed, fn_evaluations])\n    fn_typed, fn_evaluations = io.potential_suffixes([fn_typed, fn_evaluations], fn_dataset)\n    io.warn_if_path_supplied([fn_dataset, fn_typed, fn_evaluations], dir_experiment)\n    io.check_exists([fn_typed], dir_experiment)\n    return fn_dataset, fn_typed, fn_evaluations\n
"},{"location":"reference/modules/plotting/io/#nhssynth.modules.plotting.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args or from disk when the dataloader has not been run previously.

Parameters:

  • args (Namespace): The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module. Required.
  • dir_experiment (Path): The path to the experiment directory. Required.

Returns:

  • tuple[str, DataFrame, DataFrame, dict[str, dict[str, Any]]]: The data, metadata and metatransformer.

Source code in src/nhssynth/modules/plotting/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, pd.DataFrame, dict[str, dict[str, Any]]]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The data, metadata and metatransformer.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"typed\", \"evaluations\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"typed\"],\n            args.module_handover[\"evaluations\"],\n        )\n    else:\n        fn_dataset, fn_typed, fn_evaluations = check_input_paths(\n            args.dataset, args.typed, args.evaluations, dir_experiment\n        )\n\n        with open(dir_experiment / fn_typed, \"rb\") as f:\n            real_data = pickle.load(f)\n        with open(dir_experiment / fn_evaluations, \"rb\") as f:\n            evaluations = pickle.load(f)\n\n        return fn_dataset, real_data, evaluations\n
"},{"location":"reference/modules/plotting/plots/","title":"plots","text":""},{"location":"reference/modules/plotting/plots/#nhssynth.modules.plotting.plots.factorize_all_categoricals","title":"factorize_all_categoricals(df)","text":"

Factorize all categorical columns in a dataframe.

Source code in src/nhssynth/modules/plotting/plots.py
def factorize_all_categoricals(\n    df: pd.DataFrame,\n) -> pd.DataFrame:\n    \"\"\"Factorize all categorical columns in a dataframe.\"\"\"\n    for col in df.columns:\n        if df[col].dtype == \"object\":\n            df[col] = pd.factorize(df[col])[0]\n        elif df[col].dtype == \"datetime64[ns]\":\n            df[col] = pd.to_numeric(df[col])\n        min_val = df[col].min()\n        max_val = df[col].max()\n        df[col] = (df[col] - min_val) / (max_val - min_val)\n\n    return df\n
"},{"location":"reference/modules/plotting/run/","title":"run","text":""},{"location":"reference/modules/structure/","title":"structure","text":""},{"location":"reference/modules/structure/run/","title":"run","text":""}]} \ No newline at end of file +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"NHS Synth","text":"

This is a package for generating useful synthetic data, audited and assessed along the dimensions of utility, privacy and fairness. Currently, the main focus of the package in its beta stage is to experiment with different model architectures to find which are the most promising for real-world usage.

See the User Guide to get started with running an experiment with the package.

See the Development Guide and Code Reference to get started with contributing to the package.

"},{"location":"development_guide/","title":"Development guide","text":"

This document aims to provide a comprehensive set of instructions for continuing development of this package. Good knowledge of Python development is assumed. Some ways of working are subjective and a matter of preference; as such, we try to be as minimal as possible in prescribing particular ways of working.

"},{"location":"development_guide/#development-environment-setup","title":"Development environment setup","text":""},{"location":"development_guide/#python","title":"Python","text":"

The package currently supports Python 3.9, 3.10 and 3.11. We recommend installing all of these versions; at minimum the latest supported version of Python should be used. Many people use pyenv for managing multiple Python versions. On macOS, Homebrew is a good, less invasive option for this (provided you then use a virtual environment manager too). For virtual environment management, we recommend Python's in-built venv functionality, but conda or some similar system would suffice (note that in the section below it may not be necessary to use any specific virtual environment management at all, depending on the setup of Poetry).

"},{"location":"development_guide/#poetry","title":"Poetry","text":"

We use Poetry to manage dependencies and the actual packaging and publishing of NHSSynth to PyPI. Poetry is a more robust alternative to a requirements.txt file, allowing for grouped dependencies and advanced build options. Rather than freezing a specific pip state, Poetry only specifies the top-level dependencies and then handles the resolution and installation of the latest compatible versions of the full dependency tree per these top-level dependencies. See the pyproject.toml in the GitHub repository and Poetry's documentation for further context.

Once Poetry is installed (in your preferred way per the instructions on their website), you can choose one of two options:

  1. Allow Poetry to control virtual environments in its own way, such that when you install and develop the package, Poetry will automatically create a virtual environment for you.

  2. Change poetry's configuration to manage your own virtual environments:

    poetry config virtualenvs.create false\npoetry config virtualenvs.in-project false\n

    In this setup, a virtual environment can be instantiated and activated in whichever way you prefer. For example, using venv:

    python3.11 -m venv nhssynth-3.11\nsource nhssynth-3.11/bin/activate\n
"},{"location":"development_guide/#package-installation","title":"Package installation","text":"

At this point, the project dependencies can be installed via poetry install --with dev (add optional flags: --with aux to work with the auxiliary notebooks, --with docs to work with the documentation). This will install the package in editable mode, meaning that changes to the source code will be reflected in the installed package without needing to reinstall it. Note that if you are using your own virtual environment, you will need to activate it before running this command.
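
For example, to install all of the optional groups at once (the comma-separated form is standard Poetry syntax; the group names are those listed above):

poetry install --with dev,aux,docs\n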

You can then interact with the package in one of two ways:

  1. Via the CLI module, which is accessed using the nhssynth command, e.g.

    poetry run nhssynth ...\n

    Note that you can omit the poetry run part and just type nhssynth if you followed the optional steps above to manage and activate your own virtual environment, or if you have executed poetry shell beforehand.

  2. Through directly importing parts of the package to use in an existing project (from nhssynth.modules... import ...).

"},{"location":"development_guide/#secure-mode","title":"Secure mode","text":"

Note that in order to train a generator in secure mode (see the documentation for details) the PyTorch extension package csprng must be installed separately. Currently this package's dependencies are not compatible with recent versions of PyTorch (the authors plan to rectify this - watch this space), so you will need to install it manually; you can do this in your environment by running:

git clone git@github.com:pytorch/csprng.git\ncd csprng\ngit branch release \"v0.2.2-rc1\"\ngit checkout release\npython setup.py install\n
"},{"location":"development_guide/#coding-practices","title":"Coding practices","text":""},{"location":"development_guide/#style","title":"Style","text":"

We use black for code formatting. This is a fairly opinionated formatter, but it is widely used and has a good reputation. We also use ruff to manage imports and lint the code. Both of these tools are run automatically via pre-commit hooks. Ensure you have installed the package with the dev group of dependencies and then run the following command to install the hooks:

pre-commit install\n

Note that you may need to pre-pend this command with poetry run if you are not using your own virtual environment.

This will ensure that your code conforms to the two formatters' / linters' requirements each time you commit to a branch. black and ruff are also run as part of the CI workflow discussed below, such that even without these hooks, the code will be checked and raise an error on GitHub if it is not formatted consistently.

Configuration for both packages can be found in the pyproject.toml; this configuration should be picked up automatically by the pre-commit hooks, by your IDE, and when running the tools manually on the command line. The main configuration is as follows:

[tool.black]\nline-length = 120\n\n[tool.ruff]\ninclude = [\"*.py\", \"*.pyi\", \"**/pyproject.toml\", \"*.ipynb\"]\nselect = [\"E4\", \"E7\", \"E9\", \"F\", \"C90\", \"I\"]\n\n[tool.ruff.per-file-ignores]\n\"src/nhssynth/common/constants.py\" = [\"F403\", \"F405\"]\n\n[tool.ruff.isort]\nknown-first-party = [\"nhssynth\"]\n

This ensures that absolute imports from NHSSynth are sorted separately from the rest of the imports in a file.

There are a number of other hooks used as part of this repository's pre-commit configuration, including one that automatically mirrors the Poetry versions of these packages in the dev dependency group, per the list of supported packages and .poetry-sync-db.json. Roughly, these other hooks ensure correct formatting of .yaml and .toml files, check for large files being added to a commit, strip notebook output from the files, and fix whitespace and end-of-file issues. These are mostly consistent with the NHSX analytics project template's hooks.

"},{"location":"development_guide/#documentation","title":"Documentation","text":"

There should be Google-style docstrings on all non-trivial functions and classes. Ideally a docstring should take the form:

def func(arg1: type1, arg2: type2) -> returntype:\n    \"\"\"\n    One-line summary of the function.\n    AND / OR\n    Longer description of the function, including any caveats or assumptions where appropriate.\n\n    Args:\n        arg1: Description of arg1.\n        arg2: Description of arg2.\n\n    Returns:\n        Description of the return value.\n    \"\"\"\n    ...\n

These docstrings are then compiled into a full API documentation tree as part of a larger MkDocs documentation site hosted via GitHub (the one you are reading right now!). This process is derived from this tutorial.

The MkDocs page is built using the mkdocs-material theme. The documentation is built and hosted automatically via GitHub Pages.

The other parts of this site comprise markdown documents in the docs folder. Adding new pages is handled in the mkdocs.yml file as in any other Material MkDocs site. See their documentation if more complex changes to the site are required.

"},{"location":"development_guide/#testing","title":"Testing","text":"

We use tox to manage the execution of tests for the package against multiple versions of Python, and to ensure that they are being run in a clean environment. To run the tests, simply execute tox in the root directory of the repository. This will run the tests against all supported versions of Python. To run the tests against a specific version of Python, use tox -e py311 (or py310 or py39).

"},{"location":"development_guide/#configuration","title":"Configuration","text":"

See the tox.ini file for more information on the testing configuration. We follow the Poetry documentation on tox support to ensure that for each version of Python, tox will create an sdist package of the project and use pip to install it in a fresh environment. Thus, dependencies are resolved by pip in the first place and then afterwards updated to the locked dependencies in poetry.lock by running poetry install ... in this fresh environment. The tests are then run using pytest, which is configured in the pyproject.toml file. This configuration is fairly minimal: simply specifying the testing directory as the tests folder and filtering some known warnings.

[tool.pytest.ini_options]\ntestpaths = \"tests\"\nfilterwarnings = [\"ignore::DeprecationWarning:pkg_resources\"]\n

We can also use coverage to check the test coverage of the package. This is configured in the pyproject.toml file as follows:

[tool.coverage.run]\nsource = [\"src/nhssynth/cli\", \"src/nhssynth/common\", \"src/nhssynth/modules\"]\nomit = [\n    \"src/nhssynth/common/debugging.py\",\n]\n

We omit debugging.py as it is a wrapper for reading full trace-backs of warnings and not to be imported directly.

"},{"location":"development_guide/#adding-tests","title":"Adding Tests","text":"

We use the pytest framework for testing. The testing directory structure mirrors that of src. The usual testing practices apply.

"},{"location":"development_guide/#releases","title":"Releases","text":""},{"location":"development_guide/#version-management","title":"Version management","text":"

The package's version should be updated following the semantic versioning framework. The package is currently in a pre-release state, such that major version 1.0.0 should only be tagged once the package is functionally complete and stable.

To update the package's metadata, we can use Poetry's version command:

poetry version <version>\n

We can then commit and push the change to the version file:

git add pyproject.toml\ngit commit -m \"Bump version to <version>\"\ngit push\n

We should then tag the release using GitHub's CLI (or manually via git if you prefer):

gh release create <version> --generate-notes\n

This will create a new release on GitHub, and will automatically generate a changelog based on the commit messages and PR's closed since the last release. This changelog can then be edited to add more detail if necessary.

"},{"location":"development_guide/#building-and-publishing-to-pypi","title":"Building and publishing to PyPI","text":"

Poetry offers not only dependency management, but also a simple way to build and distribute the package.

After tagging a release per the section above, we can build the package using Poetry's build command:

poetry build\n

This will create a dist folder containing the built package. To publish this to PyPI, we can use the publish command:

poetry publish\n

This will prompt for PyPI credentials, and then publish the package. Note that this will only work if you have been added as a Maintainer of the package on PyPI.

It might be preferable at some point in the future to set up Trusted Publisher Management via OpenID Connect (OIDC) to allow for automated publishing of the package via a GitHub workflow. See the \"Publishing\" tab of NHSSynth's project management panel on PyPI to set this up.

"},{"location":"development_guide/#github","title":"GitHub","text":""},{"location":"development_guide/#continuous-integration","title":"Continuous integration","text":"

We use GitHub Actions for continuous integration. The different workflows comprising this can be found in the .github/workflows folder. In general, the CI workflow is triggered on every push to the main or a feature branch - as appropriate - and runs tests against all supported versions of Python. It also runs black and ruff to check that the code is formatted correctly, and builds the documentation site.

There are also scripts to update the dynamic badges in the README. These work via a gist associated with the repository. It is not easy to transfer ownership of this process, so if they break please feel free to contact me.

"},{"location":"development_guide/#branching","title":"Branching","text":"

We encourage the use of the Gitflow branching model for development. This means that the main branch is always in a stable state, and that all development work is done on feature branches. These feature branches are then merged into main via pull requests. The main branch is protected, such that pull requests must be reviewed and approved before they can be merged.

At minimum, the main branch's protection should be maintained, and roughly one branch per issue should be used. Ensure that all of the CI checks pass before merging.

"},{"location":"development_guide/#security-and-vulnerability-management","title":"Security and vulnerability management","text":"

The GitHub repository for the package has Dependabot, code scanning, and other security features enabled. These should be monitored continuously and any issues resolved as soon as possible. When issues of this type require a specific version of a dependency to be specified (and it is one that is not already amongst the dependency groups of the package), the version should be referenced as part of the security group of dependencies (i.e. with poetry add <package> --group security) and a new release created (see above).

"},{"location":"downstream_tasks/","title":"Defining a downstream task","text":"

It is likely that a synthetic dataset may be associated with specific modelling efforts or metrics that are not included in the general suite of evaluation tools supported more explicitly by this package. Additionally, analyses on model outputs for bias and fairness provided via Aequitas require some basis of predictions on which to perform the analysis. For these reasons, we provide a simple interface for defining a custom downstream task.

All downstream tasks are to be located in a folder named tasks in the working directory of the project, with subfolders for each dataset, i.e. the tasks associated with the support dataset should be located in the tasks/support directory.

The interface is then quite simple:

  • There should be a function called run that takes a single argument: dataset (additional arguments could be provided with some further configuration if there is a need for this)
  • The run function should fit a model and / or calculate some metric(s) on the dataset.
  • It should then return predicted probabilities for the outcome variable(s) in the dataset and a dictionary of metrics.
  • The file should contain a top-level variable containing an instantiation of the nhssynth Task class.

See the example below of a logistic regression model fit on the support dataset with the event variable as the outcome and rocauc as the metric of interest:

import pandas as pd\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import roc_auc_score\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.preprocessing import StandardScaler\n\nfrom nhssynth.modules.evaluation.tasks import Task\n\n\ndef run(dataset: pd.DataFrame) -> tuple[pd.DataFrame, dict]:\n    # Split the dataset into features and target\n    target = \"event\"\n\n    data = dataset.dropna()\n    X, y = data.drop([\"dob\", \"x3\", target], axis=1), data[target]\n    X_train, X_test, y_train, y_test = train_test_split(\n        StandardScaler().fit_transform(X), y, test_size=0.33, random_state=42\n    )\n\n    lr = LogisticRegression()\n    lr.fit(X_train, y_train)\n\n    # Get the predicted probabilities and predictions\n    probs = pd.DataFrame(lr.predict_proba(X_test)[:, 1], columns=[f\"lr_{target}_prob\"])\n\n    rocauc = roc_auc_score(y_test, probs)\n\n    return probs, {\"rocauc_lr\": rocauc}\n\n\ntask = Task(\"Logistic Regression on 'event'\", run, supports_aequitas=True)\n

Note the following points about the example above:

  1. The Task class has been imported from nhssynth.modules.evaluation.tasks
  2. The run function should accept one argument and return a tuple
  3. The second element of this tuple should be a dictionary labelling each metric of interest (this name will be used in the dashboard as identification so ensure it is unique to the experiment)
  4. The task should be instantiated with a name, the run function and a boolean indicating whether the task supports Aequitas analysis. If the task does not support Aequitas analysis, the first element of the tuple will not be used and None can be returned instead.

The rest of this file can contain any arbitrary code that runs within these constraints; this could be a simple model as above, or a more complex pipeline of transformations and models to match a pre-existing workflow.

"},{"location":"getting_started/","title":"Getting Started","text":""},{"location":"getting_started/#running-an-experiment","title":"Running an experiment","text":"

This package offers two easy ways to run reproducible and highly-configurable experiments. The following sections describe how to use each of these two methods.

"},{"location":"getting_started/#via-the-cli","title":"Via the CLI","text":"

The CLI is the easiest way to quickly run an experiment. It is designed to be as simple as possible, whilst still offering a high degree of configurability. An example command to run a full pipeline experiment is:

nhssynth pipeline \\\n    --experiment-name test \\\n    --dataset support \\\n    --seed 123 \\\n    --architecture DPVAE PATEGAN DECAF \\\n    --repeats 3 \\\n    --downstream-tasks \\\n    --column-similarity-metrics CorrelationSimilarity ContingencySimilarity \\\n    --column-shape-metrics KSComplement TVComplement \\\n    --boundary-metrics BoundaryAdherence \\\n    --synthesis-metrics NewRowSynthesis \\\n    --divergence-metrics ContinuousKLDivergence DiscreteKLDivergence\n

This will run a full pipeline experiment on the support dataset in the data directory. The outputs of the experiment will be recorded in a folder named test (corresponding to the experiment name) in the experiments directory.

In total, three different model architectures will be trained three times each with their default configurations. The resulting generated synthetic datasets will be evaluated via the downstream tasks in tasks/support alongside the metrics specified in the command. A dashboard will then be built automatically to exhibit the results.

The components of the run are persisted to the experiment's folder. Suppose you have already run this experiment and want to add some new evaluations. You do not have to re-run the entire experiment; you can simply run:

nhssynth evaluation -e test -d support -s 123 --coverage-metrics RangeCoverage CategoryCoverage\nnhssynth dashboard -e test -d support\n

This will regenerate the dashboard with a different set of metrics corresponding to the arguments passed to evaluation. Note that the --experiment-name and --dataset arguments are required for all commands, as they are used to identify the experiment and ensure reproducibility.

"},{"location":"getting_started/#via-a-configuration-file","title":"Via a configuration file","text":"

A yaml configuration file placed in the config folder can be used to configure a comparable run to the one above:

seed: 123\nexperiment_name: test\nrun_type: pipeline\nmodel:\n  architecture:\n    - DPVAE\n    - DPGAN\n    - DECAF\n  max_grad_norm: 5.0\n  secure_mode: false\n  repeats: 4\nevaluation:\n  downstream_tasks: true\n  column_shape_metrics:\n  - KSComplement\n  - TVComplement\n  column_similarity_metrics:\n  - CorrelationSimilarity\n  - ContingencySimilarity\n  boundary_metrics:\n  - BoundaryAdherence\n  synthesis_metrics:\n  - NewRowSynthesis\n  divergence_metrics:\n  - ContinuousKLDivergence\n  - DiscreteKLDivergence\n

Once saved as run_pipeline.yaml in the config directory, the package can be run under the configuration laid out in the file via:

nhssynth config -c run_pipeline\n

Note that if you run via the CLI, you can add the --save-config flag to your command to save the configuration file in the experiments/test (or whatever the --experiment-name has been set to) directory. This allows for easy reproduction of an experiment at a later date or on someone else's computer through sharing the configuration file with them.
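
For example, a sketch of such a run re-using the arguments from the pipeline command above:

nhssynth pipeline --experiment-name test --dataset support --seed 123 --architecture DPVAE --save-config\n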

"},{"location":"getting_started/#setting-up-a-datasets-metadata","title":"Setting up a dataset's metadata","text":"

For each dataset you wish to work with, it is advisable to set up a corresponding metadata file. The package will infer this when information is missing (and you can then tweak it). The reason we suggest specifying metadata in this way is that Pandas / Python are in general bad at interpreting CSV files, particularly the specifics of datatypes, date objects and so on.

To do this, we must create a metadata yaml file in the dataset's directory. For example, for the support dataset, this file is located at data/support_metadata.yaml. By default, the package will look for a file with the same name as the dataset in the dataset's directory, but with _metadata appended to the end. This is configurable like most other filenaming conventions via the CLI.
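
For illustration, a typical layout for the support dataset might look like the following (the .csv extension is an assumption; use whatever format your dataset is actually stored in):

data/\n    support.csv\n    support_metadata.yaml\n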

The metadata file is split into two sections: columns and constraints. The former specifies the nature of each column in the dataset, whilst the latter specifies any constraints that should be enforced on the dataset.

"},{"location":"getting_started/#column-metadata","title":"Column metadata","text":"

Again, we refer to the support dataset's metadata file as an example:

columns:\n  dob:\n    dtype:\n      name: datetime64\n      floor: S\n  x1:\n    categorical: true\n    dtype: int64\n  x2:\n    categorical: true\n    dtype: int64\n  x3:\n    categorical: true\n  x4:\n    categorical: true\n    dtype: int64\n  x5:\n    categorical: true\n    dtype: int64\n  x6:\n    categorical: true\n    dtype: int64\n  x7:\n    dtype: int64\n  x8:\n    dtype: float64\n    missingness:\n      impute: mean\n  x9:\n    dtype: int64\n  x10:\n    dtype:\n      name: float64\n      rounding_scheme: 0.1\n  x11:\n    dtype: int64\n  x12:\n    dtype: float64\n  x13:\n    dtype: float64\n  x14:\n    dtype: float64\n  duration:\n    dtype: int64\n  event:\n    categorical: true\n    dtype: int64\n

For each column in the dataset, we specify the following:

  • Its dtype; this can be any numpy data type or a datetime type.
  • In the case of a datetime type, we also specify the floor (i.e. the smallest unit of time that we care about). In general this should be set to match the smallest unit of time in the dataset.
  • In the case of a float type, we can also specify a rounding_scheme to round the values to a certain number of decimal places, again this should be set according to the rounding applied to the column in the real data, or if you want to round the values for some other reason.
  • Whether it is categorical or not. If a column is not categorical, you don't need to specify this. A column is inferred as categorical if it has less than 10 unique values or is a string type.
  • If the column has missing values, we can specify how to deal with them by specifying a missingness strategy. In the case of the x8 column, we impute the missing values with the column's mean. If you don't specify this, the CLI or configuration file's specified global missingness strategy will be applied instead (this defaults to the augment strategy, which models the missingness as a separate level in the case of categorical features, or as a separate cluster in the case of continuous features).
"},{"location":"getting_started/#constraints","title":"Constraints","text":"

The second part of the metadata file specifies any constraints that should be enforced on the dataset. These can be a relative constraint between two columns, or a fixed one via a constant on a single column. For example, the support dataset's constraints are as follows (note that these are arbitrarily defined and do not necessarily reflect the real data):

constraints:\n  - \"x10 in (0,100)\"\n  - \"x12 in (0,100)\"\n  - \"x13 in (0,100)\"\n  - \"x10 <= x12\"\n  - \"x12 < x13\"\n  - \"x10 < x13\"\n  - \"x8 > x10\"\n  - \"x8 > x12\"\n  - \"x8 > x13\"\n  - \"x11 > 100\"\n  - \"x12 > 10\"\n

The function of these constraints is fairly self-explanatory: The package ensures the constraints are feasible and minimises them before applying transformations to ensure that they will be satisfied in the synthetic data as well. When a column does not meet a feasible constraint in the real data, we assume that this is intentional and use the violation as a feature upon which to generate synthetic data that also violates the constraint.

There is a further constraint fixcombo that only applies to categorical columns. It specifies that only existing combinations of two or more categorical columns should be generated, i.e. the columns can be collapsed into a single composite feature. For example, if we have a column for pregnancy and another for sex, we may only want to allow three categories: 'male:not-pregnant', 'female:pregnant' and 'female:not-pregnant'. This is specified as follows:

constraints:\n  - \"pregnancy fixcombo sex\"\n

In conclusion then, we support the following constraint types:

  • fixcombo for categorical columns
  • < and > for non-categorical columns
  • >= and <= for non-categorical columns
  • in for non-categorical columns, which is effectively two of the above constraints combined. I.e. x in [a, b) is equivalent to x >= a and x < b. This is purely a UX feature and is treated as two separate constraints internally.
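
For instance, the constraint \"x10 in (0,100)\" from the example above is handled internally as the equivalent pair of strict constraints:

constraints:\n  - \"x10 > 0\"\n  - \"x10 < 100\"\n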

Once this metadata is setup, you are ready to run your experiment.

"},{"location":"getting_started/#evaluation","title":"Evaluation","text":"

Once models have been trained and synthetic datasets generated, we leverage evaluations from SDMetrics, Aequitas, and the NHS' internal SynAdvSuite (at the current time you must request access to this repository to use the privacy-related attacks it implements), and also offer a facility for the custom specification of downstream tasks. These evaluations are then aggregated into a dashboard for ease of comparison and analysis.

See the relevant documentation for each of these packages for more information on the metrics they offer.

"},{"location":"model_card/","title":"Model Card: Variational AutoEncoder with Differential Privacy","text":""},{"location":"model_card/#model-details","title":"Model Details","text":"

The implementation of the Variational AutoEncoder (VAE) with Differential Privacy within this repository is based on work done by Dominic Danks during an NHSX Analytics Unit PhD internship (last commit to the original SynthVAE repository: commit 88a4bdf). This model card describes an updated and extended version of the model, by Harrison Wilde. Further information about the previous version created by Dom and its model implementation can be found in Section 5.4 of the associated report.

"},{"location":"model_card/#model-use","title":"Model Use","text":""},{"location":"model_card/#intended-use","title":"Intended Use","text":"

This model is intended for use in experimenting with differential privacy and VAEs.

"},{"location":"model_card/#training-data","title":"Training Data","text":"

Experiments in this repository are run against the Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT) dataset accessed via the pycox Python library. We also performed further analysis on a single table that we extracted from MIMIC-III.

"},{"location":"model_card/#performance-and-limitations","title":"Performance and Limitations","text":"

A from-scratch VAE implementation was compared against various models available within the SDV framework using a variety of quality and privacy metrics on the SUPPORT dataset. The VAE was found to be competitive with all of these models across the various metrics. Differential Privacy (DP) was introduced via DP-SGD and the performance of the VAE for different levels of privacy was evaluated. It was found that as the level of Differential Privacy introduced by DP-SGD was increased, it became easier to distinguish between synthetic and real data.

Proper evaluation of quality and privacy of synthetic data is challenging. In this work, we utilised metrics from the SDV library due to their natural integration with the rest of the codebase. A valuable extension of this work would be to apply a variety of external metrics, including more advanced adversarial attacks to more thoroughly evaluate the privacy of the considered methods, including as the level of DP is varied. It would also be of interest to apply DP-SGD and/or PATE to all of the considered methods and evaluate whether the performance drop as a function of implemented privacy is similar or different across the models.

Currently the SynthVAE model only works for data which is 'clean', i.e. data that has no missingness or NaNs within its input. It can handle continuous, categorical and datetime variables. Special types such as nominal data cannot be handled properly; however, the model may still run. Column names have to be specified in the code for the variable group they belong to.

Hyperparameter tuning of the model can result in errors if certain parameter values are selected. Most commonly, changing the learning rate in our example results in errors during training. An extensive test to evaluate plausible ranges has not yet been performed. If you get errors during tuning, consider your hyperparameter values and adjust accordingly.

"},{"location":"model_card/#acknowledgements","title":"Acknowledgements","text":"

This documentation is inspired by Model Cards for Model Reporting (Mitchell et al.) and Lessons from Archives (Jo & Gebru).

"},{"location":"models/","title":"Adding new models","text":"

The model module contains all of the architectures implemented as part of this package. We offer GAN and VAE based architectures with a number of adjustments to achieve privacy and other augmented functionalities. The module handles the training and generation of synthetic data using these architectures, per a user's choice of model(s) and configuration.

It is likely that as the literature matures, more effective architectures will present themselves as promising for application to the type of tabular data NHSSynth is designed for. Below we discuss how to add new models to the package.

"},{"location":"models/#model-design","title":"Model design","text":"

The models in this package are built entirely in PyTorch and use Opacus for differential privacy.

We have built the VAE and (Tabular)GAN implementations in this package to serve as the foundations for a number of other architectures. As such, we try to maintain a somewhat modular design to building up more complex differentially private (or otherwise augmented) architectures. Each model inherits from either the GAN or VAE class (in files of the same name) which in turn inherit from a generic Model class found in the common folder. This folder contains components of models which are not to be instantiated themselves, e.g. a mixin class for differential privacy, the MLP underlying the GAN and so on.

The Model class from which all of the models derive handles all of the general attributes. Roughly, these are the specifics of the dataset the instance of the model is relative to, the device that training is to be carried out upon, and other training parameters such as the total number of epochs to execute.

We define these things at the model level, as when using differential privacy or other privacy accountant methods, we must know ahead of time the data and length of training exposure in order to calculate the levels of noise required to reach a certain privacy guarantee and so on.

"},{"location":"models/#implementing-a-new-model","title":"Implementing a new model","text":"

In order to add a new architecture then, it is important to first investigate the modular parts already implemented to ensure that what you want to build is not already possible through the composition of these existing parts. Then you must ensure that your architecture either inherits from the GAN or VAE, or Model if you wish to implement a different type of generative model.

In all of these cases, the interface expects for the implementation to have the following methods:

  • get_args: a class method that lists the architecture specific arguments that the model requires. This is used to facilitate default arguments in the python API whilst still allowing for arguments in the CLI to be propagated and recorded automatically in the experiment output. This should be a list of variable names equal to the concatenation of all of the non-Model parent classes (e.g. DPVAE has DP and VAE args) plus any architecture specific arguments in the __init__ method of the model in question.
  • get_metrics: another class method that behaves similarly to the above, should return a list of valid metrics to track during training for this model
  • train: a method handling the training loop for the model. This should take num_epochs, patience and displayed_metrics as arguments and return a tuple containing the number of epochs that were executed plus a bundle of training metrics (the values over time returned by get_metrics on the class). In the execution of this method, the utility methods defined in Model should be called in order, _start_training at the beginning, then _record_metrics at each training step of the data loader, and finally _finish_training to clean up progress bars and so on. displayed_metrics determines which metrics are actively displayed during training.
  • generate: a method to call on the trained model which generates N samples of data, and calls the model's associated MetaTransformer to return a valid pandas DataFrame of synthetic data ready to output. A minimal sketch of this interface is given below.
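
The following is a minimal, hypothetical sketch of this interface for a new GAN-based extension; the import path, the parent class' get_args, and the _training_step / _sample helpers are assumptions standing in for real implementation details, not the package's actual API:

from typing import Optional\n\nimport numpy as np\nimport pandas as pd\n\nfrom nhssynth.modules.model.models.gan import GAN  # assumed import path\n\n\nclass MyGAN(GAN):\n    \"\"\"A hypothetical GAN extension with one architecture-specific argument.\"\"\"\n\n    def __init__(self, *args, my_weight: float = 0.1, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.my_weight = my_weight\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        # Concatenate the parent class' args with the architecture-specific ones\n        return GAN.get_args() + [\"my_weight\"]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\"GLoss\", \"DLoss\"]\n\n    def train(\n        self,\n        num_epochs: int = 100,\n        patience: int = 5,\n        displayed_metrics: list[str] = [\"GLoss\"],\n    ) -> tuple[int, dict[str, list[float]]]:\n        self._start_training(num_epochs, patience, displayed_metrics)\n        for epoch in range(num_epochs):\n            for (X,) in self.data_loader:\n                losses = self._training_step(X)  # hypothetical architecture-specific step\n                self._record_metrics(losses)\n            if self._check_patience(epoch, np.mean(self.metrics[\"GLoss\"][-len(self.data_loader) :])):\n                num_epochs = epoch + 1\n                break\n        self._finish_training(num_epochs)\n        return (num_epochs, self.metrics)\n\n    def generate(self, N: Optional[int] = None) -> pd.DataFrame:\n        X_gen = self._sample(N or self.nrows)  # hypothetical architecture-specific sampling\n        return self.metatransformer.inverse_apply(pd.DataFrame(X_gen, columns=self.columns))\n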
"},{"location":"models/#adding-a-new-model-to-the-cli","title":"Adding a new model to the CLI","text":"

Once you have implemented your new model, you must add it to the CLI. To do this, we must first export the model's class into the MODELS constant in the __init__ file in the models subfolder. We can then add a new function and option in module_arguments.py to list the arguments and their types unique to this type of architecture.

Note that you should not duplicate arguments that are already defined in the Model class or in foundational model architectures such as the GAN if you are implementing an extension to them. If you have set up get_args correctly, all of this will be propagated automatically.
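
As a rough, hypothetical sketch (the exact shape of the MODELS constant and the signature of the argument-adding function are assumptions, mirroring the add_mymodule_args pattern shown in the modules guide):

# src/nhssynth/modules/model/models/__init__.py\nfrom nhssynth.modules.model.models.mygan import MyGAN\n\nMODELS = {\n    # ... existing architectures ...\n    \"MyGAN\": MyGAN,\n}\n\n# src/nhssynth/cli/module_arguments.py\ndef add_mygan_args(parser: argparse.ArgumentParser, group_title: str, overrides=False):\n    group = parser.add_argument_group(title=group_title)\n    # only the MyGAN-specific arguments; GAN/Model arguments are propagated via get_args\n    group.add_argument(\"--my-weight\", type=float, default=0.1)\n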

"},{"location":"modules/","title":"Adding new modules","text":"

The package is designed such that each module can be used as part of a pipeline (via the CLI or a configuration file) or independently (via importing them into an existing codebase).

In the future it may be desirable to add or adjust the modules of the package; this guide offers a high-level overview of how to do so.

"},{"location":"modules/#importing-a-module-from-this-package","title":"Importing a module from this package","text":"

After installing the package, you can simply do:

from nhssynth.modules import <module>\n
and you will be able to use it in your code!

"},{"location":"modules/#creating-a-new-module-and-folding-it-into-the-cli","title":"Creating a new module and folding it into the CLI","text":"

The following instructions specify how to extend this package with a new module:

  1. Create a folder for your module within the package, i.e. src/nhssynth/modules/mymodule
  2. Include within it a main executor function that accepts arguments from the CLI, i.e.

    def myexecutor(args):\n    ...\n

    In mymodule/executor.py and export it by adding from .executor import myexecutor to mymodule/__init__.py. Check the existing modules for examples of what a typical executor function looks like.

  3. In the cli folder, add a corresponding function to module_arguments.py and populate with arguments you want to expose in the CLI:

    def add_mymodule_args(parser: argparse.ArgumentParser, group_title: str, overrides=False):\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(...)\n    group.add_argument(...)\n    ...\n
  4. Next, in module_setup.py make the following adjustments to the MODULE_MAP code:

    MODULE_MAP = {\n    ...\n    \"mymodule\": ModuleConfig(\n        func=m.mymodule.myexecutor,\n        add_args=ma.add_mymodule_args,\n        description=\"...\",\n        help=\"...\",\n        common_parsers=[...]\n    ),\n    ...\n}\n

    Where common_parsers is a subset of COMMON_PARSERS defined in common_arguments.py. Note that the \"seed\" and \"core\" parsers are added automatically, so you don't need to specify them. These parsers can be used to add arguments to your module that are common to multiple modules, e.g. the dataloader and evaluation modules both use --typed to specify the path of the typed input dataset.
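
For example, assuming the parser keys mirror the corresponding argument names (an illustrative assumption), a module that consumes the typed dataset and produces synthetic output might declare:

common_parsers=[\"typed\", \"synthetic\"]\n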

  5. You can (optionally) also edit the following block if you want your module to be included in a full pipeline run:

    PIPELINE = [..., mymodule, ...]  # NOTE this determines the order of a pipeline run\n
  6. Congrats, your module is implemented within the CLI; its documentation etc. will now be built automatically and it can be referenced in configuration files!

"},{"location":"secure_mode/","title":"Opacus' secure mode","text":"

Part of the process for achieving a differential privacy guarantee under Opacus involves generating noise according to a Gaussian distribution with mean 0 in Opacus' _generate_noise() function.

Enabling secure_mode when using the NHSSynth package ensures that the generated noise is also secure against floating point representation attacks, such as the ones in https://arxiv.org/abs/2107.10138 and https://arxiv.org/abs/2112.05307.

This attack first appeared in https://arxiv.org/abs/2112.05307; the fix via the csprng package is based on https://arxiv.org/abs/2107.10138 and involves calling the Gaussian noise function $2n$ times, where $n=2$ (see section 5.1 in https://arxiv.org/abs/2107.10138).

The reason for choosing $n=2$ is that $n$ can be any number greater than $1$. The bigger $n$ is, though, the more computation needs to be done to generate the Gaussian samples. The choice of $n=2$ is justified via the knowledge that the attack has a complexity of $2^{p(2n-1)}$. In PyTorch, $p=53$, so with $n=2$ the complexity is $2^{p(2n-1)} = 2^{53 \times 3} = 2^{159}$, which is deemed sufficiently hard for an attacker to break.

"},{"location":"reference/SUMMARY/","title":"SUMMARY","text":"
  • cli
    • common_arguments
    • config
    • model_arguments
    • module_arguments
    • module_setup
    • run
  • common
    • common
    • constants
    • debugging
    • dicts
    • io
    • strings
  • modules
    • dashboard
      • Upload
      • io
      • pages
        • 1_Tables
        • 2_Plots
        • 3_Experiment_Configurations
      • run
      • utils
    • dataloader
      • constraints
      • io
      • metadata
      • metatransformer
      • missingness
      • run
      • transformers
        • base
        • categorical
        • continuous
        • datetime
    • evaluation
      • aequitas
      • io
      • metrics
      • run
      • tasks
      • utils
    • model
      • common
        • dp
        • mlp
        • model
      • io
      • models
        • dpvae
        • gan
        • vae
      • run
      • utils
    • plotting
      • io
      • plots
      • run
    • structure
      • run
"},{"location":"reference/cli/","title":"cli","text":""},{"location":"reference/cli/common_arguments/","title":"common_arguments","text":"

Functions to define the CLI's \"common\" arguments, i.e. those that can be applied to either:

  • All module argument lists, e.g. --dataset, --seed, etc.
  • A subset of module(s) argument lists, e.g. --synthetic, --typed, etc.

"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.get_core_parser","title":"get_core_parser(overrides=False)","text":"

Create the core common parser group applied to all modules (and the pipeline and config options). Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.

Parameters:

Name Type Description Default overrides

whether the arguments declared within are required or not.

False

Returns:

Type Description ArgumentParser

The parser with the group containing the core arguments attached.

Source code in src/nhssynth/cli/common_arguments.py
def get_core_parser(overrides=False) -> argparse.ArgumentParser:\n    \"\"\"\n    Create the core common parser group applied to all modules (and the `pipeline` and `config` options).\n    Note that we leverage common titling of the argument group to ensure arguments appear together even if declared separately.\n\n    Args:\n        overrides: whether the arguments declared within are required or not.\n\n    Returns:\n        The parser with the group containing the core arguments attached.\n    \"\"\"\n    \"\"\"\"\"\"\n    core = argparse.ArgumentParser(add_help=False)\n    core_grp = core.add_argument_group(title=\"options\")\n    core_grp.add_argument(\n        \"-d\",\n        \"--dataset\",\n        required=(not overrides),\n        type=str,\n        help=\"the name of the dataset to experiment with, should be present in `<DATA_DIR>`\",\n    )\n    core_grp.add_argument(\n        \"-e\",\n        \"--experiment-name\",\n        type=str,\n        default=TIME,\n        help=\"name the experiment run to affect logging, config, and default-behaviour i/o\",\n    )\n    core_grp.add_argument(\n        \"--save-config\",\n        action=\"store_true\",\n        help=\"save the config provided via the cli, this is a recommended option for reproducibility\",\n    )\n    return core\n
"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.get_seed_parser","title":"get_seed_parser(overrides=False)","text":"

Create the common parser group for the seed. NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.

Parameters:

Name Type Description Default overrides

whether the arguments declared within are required or not.

False

Returns:

Type Description ArgumentParser

The parser with the group containing the seed argument attached.

Source code in src/nhssynth/cli/common_arguments.py
def get_seed_parser(overrides=False) -> argparse.ArgumentParser:\n    \"\"\"\n    Create the common parser group for the seed.\n    NB This is separate to the rest of the core arguments as it does not apply to the dashboard module.\n\n    Args:\n        overrides: whether the arguments declared within are required or not.\n\n    Returns:\n        The parser with the group containing the seed argument attached.\n    \"\"\"\n    parser = argparse.ArgumentParser(add_help=False)\n    parser_grp = parser.add_argument_group(title=\"options\")\n    parser_grp.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        help=\"specify a seed for reproducibility, this is a recommended option for reproducibility\",\n    )\n    return parser\n
"},{"location":"reference/cli/common_arguments/#nhssynth.cli.common_arguments.suffix_parser_generator","title":"suffix_parser_generator(name, help, required=False)","text":"

Generator function for creating parsers following a common template. These parsers are all suffixes to the --dataset / -d / DATASET argument, see COMMON_TITLE.

Parameters:

Name Type Description Default name str

the name / label of the argument to add to the CLI options.

required help str

the help message when the CLI is run with --help / -h.

required required bool

whether the argument must be provided or not.

False Source code in src/nhssynth/cli/common_arguments.py
def suffix_parser_generator(name: str, help: str, required: bool = False) -> argparse.ArgumentParser:\n    \"\"\"Generator function for creating parsers following a common template.\n    These parsers are all suffixes to the --dataset / -d / DATASET argument, see `COMMON_TITLE`.\n\n    Args:\n        name: the name / label of the argument to add to the CLI options.\n        help: the help message when the CLI is run with --help / -h.\n        required: whether the argument must be provided or not.\n    \"\"\"\n\n    def get_parser(overrides: bool = False) -> argparse.ArgumentParser:\n        parser = argparse.ArgumentParser(add_help=False)\n        parser_grp = parser.add_argument_group(title=COMMON_TITLE)\n        parser_grp.add_argument(\n            f\"--{name.replace('_', '-')}\",\n            required=required and not overrides,\n            type=str,\n            default=f\"_{name}\",\n            help=help,\n        )\n        return parser\n\n    return get_parser\n
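
A hedged usage sketch (the help text is illustrative, and \"typed\" is assumed to be a valid suffix parser name):

get_typed_parser = suffix_parser_generator(\"typed\", \"filename of the typed version of the dataset\")\nparser = get_typed_parser()  # adds --typed (type=str) with default \"_typed\"\n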
"},{"location":"reference/cli/config/","title":"config","text":"

Read, write and process config files, including handling of module-specific / common config overrides.

"},{"location":"reference/cli/config/#nhssynth.cli.config.assemble_config","title":"assemble_config(args, all_subparsers)","text":"

Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.

Parameters:

Name Type Description Default args Namespace

A namespace object containing all parsed command-line arguments.

required all_subparsers dict[str, ArgumentParser]

A dictionary mapping module names to subparser objects.

required

Returns:

Type Description dict[str, Any]

A dictionary containing configuration information extracted from args in a module-wise nested format that is YAML-friendly.

Raises:

Type Description ValueError

If a module specified in args.modules_to_run is not in all_subparsers.

Source code in src/nhssynth/cli/config.py
def assemble_config(\n    args: argparse.Namespace,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> dict[str, Any]:\n    \"\"\"\n    Assemble and arrange a nested-via-module configuration dictionary from parsed command-line arguments to be output as a YAML record.\n\n    Args:\n        args: A namespace object containing all parsed command-line arguments.\n        all_subparsers: A dictionary mapping module names to subparser objects.\n\n    Returns:\n        A dictionary containing configuration information extracted from `args` in a module-wise nested format that is YAML-friendly.\n\n    Raises:\n        ValueError: If a module specified in `args.modules_to_run` is not in `all_subparsers`.\n    \"\"\"\n    args_dict = vars(args)\n\n    # Filter out the keys that are not relevant to the config file\n    args_dict = filter_dict(\n        args_dict, {\"func\", \"experiment_name\", \"save_config\", \"save_config_path\", \"module_handover\"}\n    )\n    for k in args_dict.copy().keys():\n        # Remove empty metric lists from the config\n        if \"_metrics\" in k and not args_dict[k]:\n            args_dict.pop(k)\n\n    modules_to_run = args_dict.pop(\"modules_to_run\")\n    if len(modules_to_run) == 1:\n        run_type = modules_to_run[0]\n    elif modules_to_run == PIPELINE:\n        run_type = \"pipeline\"\n    else:\n        raise ValueError(f\"Invalid value for `modules_to_run`: {modules_to_run}\")\n\n    # Generate a dictionary containing each module's name from the run, with all of its possible corresponding config args\n    module_args = {\n        module_name: [action.dest for action in all_subparsers[module_name]._actions if action.dest != \"help\"]\n        for module_name in modules_to_run\n    }\n\n    # Use the flat namespace to populate a nested (by module) dictionary of config args and values\n    out_dict = {}\n    for module_name in modules_to_run:\n        for k in args_dict.copy().keys():\n            # We want to keep dataset, experiment_name, seed and save_config at the top-level as they are core args\n            if k in module_args[module_name] and k not in {\n                \"version\",\n                \"dataset\",\n                \"experiment_name\",\n                \"seed\",\n                \"save_config\",\n            }:\n                if module_name not in out_dict:\n                    out_dict[module_name] = {}\n                v = args_dict.pop(k)\n                if v is not None:\n                    out_dict[module_name][k] = v\n\n    # Assemble the final dictionary in YAML-compliant form\n    return {**({\"run_type\": run_type} if run_type else {}), **args_dict, **out_dict}\n
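
As a simplified, illustrative sketch (the values are made up and a real run would include every argument of the module), a run of only the model module might be assembled into something like:

{\n    \"run_type\": \"model\",\n    \"dataset\": \"support\",\n    \"seed\": 123,\n    \"model\": {\n        \"architecture\": [\"VAE\"],\n        \"num_epochs\": 50,\n    },\n}\n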
"},{"location":"reference/cli/config/#nhssynth.cli.config.get_default_and_required_args","title":"get_default_and_required_args(top_parser, module_parsers)","text":"

Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.

Parameters:

Name Type Description Default top_parser ArgumentParser

The top-level parser (contains common arguments).

required module_parsers dict[str, ArgumentParser]

The dict of module-level parsers mapped to their names.

required

Returns:

Type Description tuple[dict[str, Any], list[str]]

A tuple containing two elements: a dictionary containing all arguments and their default values, and a list of key-value pairs mapping each required argument to its associated module.

Source code in src/nhssynth/cli/config.py
def get_default_and_required_args(\n    top_parser: argparse.ArgumentParser,\n    module_parsers: dict[str, argparse.ArgumentParser],\n) -> tuple[dict[str, Any], list[str]]:\n    \"\"\"\n    Get the default and required arguments for the top-level parser and the current run's corresponding list of module parsers.\n\n    Args:\n        top_parser: The top-level parser (contains common arguments).\n        module_parsers: The dict of module-level parsers mapped to their names.\n\n    Returns:\n        A tuple containing two elements:\n            - A dictionary containing all arguments and their default values.\n            - A list of key-value-pairs of the required arguments and their associated module.\n    \"\"\"\n    all_actions = {\"top-level\": top_parser._actions} | {m: p._actions for m, p in module_parsers.items()}\n    defaults = {}\n    required_args = []\n    for module, actions in all_actions.items():\n        for action in actions:\n            if action.dest not in [\"help\", \"==SUPPRESS==\"]:\n                defaults[action.dest] = action.default\n                if action.required:\n                    required_args.append({\"arg\": action.dest, \"module\": module})\n    return defaults, required_args\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.get_modules_to_run","title":"get_modules_to_run(executor)","text":"

Get the list of modules to run from the passed executor function.

Parameters:

Name Type Description Default executor Callable

The executor function to run.

required

Returns:

Type Description list[str]

A list of module names to run.

Source code in src/nhssynth/cli/config.py
def get_modules_to_run(executor: Callable) -> list[str]:\n    \"\"\"\n    Get the list of modules to run from the passed executor function.\n\n    Args:\n        executor: The executor function to run.\n\n    Returns:\n        A list of module names to run.\n    \"\"\"\n    if executor == run_pipeline:\n        return PIPELINE\n    else:\n        return [get_key_by_value({mn: mc.func for mn, mc in MODULE_MAP.items()}, executor)]\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.read_config","title":"read_config(args, parser, all_subparsers)","text":"

Hierarchically assembles a config argparse.Namespace object for the inferred modules to run and execute, given a file.

  1. Load the YAML file containing the config to read from
  2. Check a valid run_type is specified or infer it and determine the list of modules_to_run
  3. Establish the appropriate default configuration set of arguments from the parser and all_subparsers for the determined modules_to_run
  4. Overwrite these with the specified (sub)set of config in the YAML file
  5. Overwrite again with passed command-line args (these are considered 'overrides')
  6. Run the appropriate module(s) or pipeline with the resulting configuration Namespace object

Parameters:

Name Type Description Default args Namespace

Namespace object containing arguments from the command line

required parser ArgumentParser

top-level ArgumentParser object containing common arguments

required all_subparsers dict[str, ArgumentParser]

dictionary of ArgumentParser objects, one for each module

required

Returns:

Type Description Namespace

A Namespace object containing the assembled configuration settings

Raises:

Type Description AssertionError

if any required arguments are missing from the configuration file / overrides

Source code in src/nhssynth/cli/config.py
def read_config(\n    args: argparse.Namespace,\n    parser: argparse.ArgumentParser,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> argparse.Namespace:\n    \"\"\"\n    Hierarchically assembles a config `argparse.Namespace` object for the inferred modules to run and execute, given a file.\n\n    1. Load the YAML file containing the config to read from\n    2. Check a valid `run_type` is specified or infer it and determine the list of `modules_to_run`\n    3. Establish the appropriate default configuration set of arguments from the `parser` and `all_subparsers` for the determined `modules_to_run`\n    4. Overwrite these with the specified (sub)set of config in the YAML file\n    5. Overwrite again with passed command-line `args` (these are considered 'overrides')\n    6. Run the appropriate module(s) or pipeline with the resulting configuration `Namespace` object\n\n    Args:\n        args: Namespace object containing arguments from the command line\n        parser: top-level `ArgumentParser` object containing common arguments\n        all_subparsers: dictionary of `ArgumentParser` objects, one for each module\n\n    Returns:\n        A Namespace object containing the assembled configuration settings\n\n    Raises:\n        AssertionError: if any required arguments are missing from the configuration file / overrides\n    \"\"\"\n    # Open the passed yaml file and load into a dictionary\n    with open(f\"config/{args.input_config}.yaml\") as stream:\n        config_dict = yaml.safe_load(stream)\n\n    valid_run_types = [x for x in all_subparsers.keys() if x != \"config\"]\n\n    version = config_dict.pop(\"version\", None)\n    if version and version != ver(\"nhssynth\"):\n        warnings.warn(\n            f\"This config file's specified version ({version}) does not match the currently installed version of nhssynth ({ver('nhssynth')}), results may differ.\"\n        )\n    elif not version:\n        version = ver(\"nhssynth\")\n\n    run_type = config_dict.pop(\"run_type\", None)\n\n    if run_type == \"pipeline\":\n        modules_to_run = PIPELINE\n    else:\n        modules_to_run = [x for x in config_dict.keys() | {run_type} if x in valid_run_types]\n        if not args.custom_pipeline:\n            modules_to_run = sorted(modules_to_run, key=lambda x: PIPELINE.index(x))\n\n    if not modules_to_run:\n        warnings.warn(\n            f\"Missing or invalid `run_type` and / or module specification hierarchy in `config/{args.input_config}.yaml`, defaulting to a full run of the pipeline\"\n        )\n        modules_to_run = PIPELINE\n\n    # Get all possible default arguments by scraping the top level `parser` and the appropriate sub-parser for the `run_type`\n    args_dict, required_args = get_default_and_required_args(\n        parser, filter_dict(all_subparsers, modules_to_run, include=True)\n    )\n\n    # Find the non-default arguments amongst passed `args` by seeing which of them are different to the entries of `args_dict`\n    non_default_passed_args_dict = {\n        k: v\n        for k, v in vars(args).items()\n        if k in [\"input_config\", \"custom_pipeline\"] or (k in args_dict and k != \"func\" and v != args_dict[k])\n    }\n\n    # Overwrite the default arguments with the ones from the yaml file\n    args_dict.update(flatten_dict(config_dict))\n\n    # Overwrite the result of the above with any non-default CLI args\n    args_dict.update(non_default_passed_args_dict)\n\n    # Create a new Namespace using the assembled dictionary\n    new_args = argparse.Namespace(**args_dict)\n    assert getattr(\n        new_args, \"dataset\"\n    ), \"No dataset specified in the passed config file, provide one with the `--dataset` argument or add it to the config file\"\n    assert all(\n        getattr(new_args, req_arg[\"arg\"]) for req_arg in required_args\n    ), f\"Required arguments are missing from the passed config file: {[ra['module'] + ':' + ra['arg'] for ra in required_args if not getattr(new_args, ra['arg'])]}\"\n\n    # Run the appropriate execution function(s)\n    if not new_args.seed:\n        warnings.warn(\"No seed has been specified, meaning the results of this run may not be reproducible.\")\n    new_args.version = version\n    new_args.modules_to_run = modules_to_run\n    new_args.module_handover = {}\n    for module in new_args.modules_to_run:\n        MODULE_MAP[module](new_args)\n\n    return new_args\n
"},{"location":"reference/cli/config/#nhssynth.cli.config.write_config","title":"write_config(args, all_subparsers)","text":"

Assembles a configuration dictionary from the run config and writes it to a YAML file at the location specified by args.save_config_path.

Parameters:

Name Type Description Default args Namespace

A namespace containing the run's configuration.

required all_subparsers dict[str, ArgumentParser]

A dictionary containing all subparsers for the config args.

required Source code in src/nhssynth/cli/config.py
def write_config(\n    args: argparse.Namespace,\n    all_subparsers: dict[str, argparse.ArgumentParser],\n) -> None:\n    \"\"\"\n    Assembles a configuration dictionary from the run config and writes it to a YAML file at the location specified by `args.save_config_path`.\n\n    Args:\n        args: A namespace containing the run's configuration.\n        all_subparsers: A dictionary containing all subparsers for the config args.\n    \"\"\"\n    experiment_name = args.experiment_name\n    args_dict = assemble_config(args, all_subparsers)\n    with open(f\"experiments/{experiment_name}/config_{experiment_name}.yaml\", \"w\") as yaml_file:\n        yaml.dump(args_dict, yaml_file, default_flow_style=False, sort_keys=False)\n
"},{"location":"reference/cli/model_arguments/","title":"model_arguments","text":"

Define arguments for each of the model classes.

"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_gan_args","title":"add_gan_args(group, overrides=False)","text":"

Adds arguments to an existing group for the GAN model.

Source code in src/nhssynth/cli/model_arguments.py
def add_gan_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group for the GAN model.\"\"\"\n    group.add_argument(\n        \"--n-units-conditional\",\n        type=int,\n        help=\"the number of units in the conditional layer\",\n    )\n    group.add_argument(\n        \"--generator-n-layers-hidden\",\n        type=int,\n        help=\"the number of hidden layers in the generator\",\n    )\n    group.add_argument(\n        \"--generator-n-units-hidden\",\n        type=int,\n        help=\"the number of units in each hidden layer of the generator\",\n    )\n    group.add_argument(\n        \"--generator-activation\",\n        type=str,\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the generator\",\n    )\n    group.add_argument(\n        \"--generator-batch-norm\",\n        action=\"store_true\",\n        help=\"whether to use batch norm in the generator\",\n    )\n    group.add_argument(\n        \"--generator-dropout\",\n        type=float,\n        help=\"the dropout rate in the generator\",\n    )\n    group.add_argument(\n        \"--generator-lr\",\n        type=float,\n        help=\"the learning rate for the generator\",\n    )\n    group.add_argument(\n        \"--generator-residual\",\n        action=\"store_true\",\n        help=\"whether to use residual connections in the generator\",\n    )\n    group.add_argument(\n        \"--generator-opt-betas\",\n        type=float,\n        nargs=2,\n        help=\"the beta values for the generator optimizer\",\n    )\n    group.add_argument(\n        \"--discriminator-n-layers-hidden\",\n        type=int,\n        help=\"the number of hidden layers in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-n-units-hidden\",\n        type=int,\n        help=\"the number of units in each hidden layer of the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-activation\",\n        type=str,\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-batch-norm\",\n        action=\"store_true\",\n        help=\"whether to use batch norm in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-dropout\",\n        type=float,\n        help=\"the dropout rate in the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-lr\",\n        type=float,\n        help=\"the learning rate for the discriminator\",\n    )\n    group.add_argument(\n        \"--discriminator-opt-betas\",\n        type=float,\n        nargs=2,\n        help=\"the beta values for the discriminator optimizer\",\n    )\n    group.add_argument(\n        \"--clipping-value\",\n        type=float,\n        help=\"the clipping value for the discriminator\",\n    )\n    group.add_argument(\n        \"--lambda-gradient-penalty\",\n        type=float,\n        help=\"the gradient penalty coefficient\",\n    )\n
"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_model_specific_args","title":"add_model_specific_args(group, name, overrides=False)","text":"

Adds arguments to an existing group according to name.

Source code in src/nhssynth/cli/model_arguments.py
def add_model_specific_args(group: argparse._ArgumentGroup, name: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group according to `name`.\"\"\"\n    if name == \"VAE\":\n        add_vae_args(group, overrides)\n    elif name == \"GAN\":\n        add_gan_args(group, overrides)\n    elif name == \"TabularGAN\":\n        add_tabular_gan_args(group, overrides)\n
"},{"location":"reference/cli/model_arguments/#nhssynth.cli.model_arguments.add_vae_args","title":"add_vae_args(group, overrides=False)","text":"

Adds arguments to an existing group for the VAE model.

Source code in src/nhssynth/cli/model_arguments.py
def add_vae_args(group: argparse._ArgumentGroup, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing group for the VAE model.\"\"\"\n    group.add_argument(\n        \"--encoder-latent-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the latent dimension of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-hidden-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the hidden dimension of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-activation\",\n        type=str,\n        nargs=\"+\",\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the encoder\",\n    )\n    group.add_argument(\n        \"--encoder-learning-rate\",\n        type=float,\n        nargs=\"+\",\n        help=\"the learning rate for the encoder\",\n    )\n    group.add_argument(\n        \"--decoder-latent-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the latent dimension of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-hidden-dim\",\n        type=int,\n        nargs=\"+\",\n        help=\"the hidden dimension of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-activation\",\n        type=str,\n        nargs=\"+\",\n        choices=list(ACTIVATION_FUNCTIONS.keys()),\n        help=\"the activation function of the decoder\",\n    )\n    group.add_argument(\n        \"--decoder-learning-rate\",\n        type=float,\n        nargs=\"+\",\n        help=\"the learning rate for the decoder\",\n    )\n    group.add_argument(\n        \"--shared-optimizer\",\n        action=\"store_true\",\n        help=\"whether to use a shared optimizer for the encoder and decoder\",\n    )\n
"},{"location":"reference/cli/module_arguments/","title":"module_arguments","text":"

Define arguments for each of the modules' CLI sub-parsers.

"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.AllChoicesDefault","title":"AllChoicesDefault","text":"

Bases: Action

Customised argparse action that defaults to the full list of choices when only the argument's flag is supplied (i.e. if a user passes --metrics with no follow-up list of metric groups, all metric groups will be executed).

Notes

1) If no option_string is supplied, set to the default value (self.default). 2) If option_string is supplied: a) if values are supplied, set to the list of values; b) if no values are supplied, set to self.const, or to self.default if self.const is not set.

Source code in src/nhssynth/cli/module_arguments.py
class AllChoicesDefault(argparse.Action):\n    \"\"\"\n    Customised argparse action for defaulting to the full list of choices if only the argument's flag is supplied:\n    (i.e. user passes `--metrics` with no follow up list of metric groups => all metric groups will be executed).\n\n    Notes:\n        1) If no `option_string` is supplied: set to default value (`self.default`)\n        2) If `option_string` is supplied:\n            a) If `values` are supplied, set to list of values\n            b) If no `values` are supplied, set to `self.const`, if `self.const` is not set, set to `self.default`\n    \"\"\"\n\n    def __call__(self, parser, namespace, values=None, option_string=None):\n        if values:\n            setattr(namespace, self.dest, values)\n        elif option_string:\n            setattr(namespace, self.dest, self.const if self.const else self.default)\n        else:\n            setattr(namespace, self.dest, self.default)\n
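
As a hedged illustration of how this action might be attached to an argument (the flag name follows the note above, but the metric group names, nargs setting and defaults are purely illustrative):

import argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--metrics\",\n    action=AllChoicesDefault,\n    nargs=\"*\",\n    choices=[\"quality\", \"privacy\"],\n    const=[\"quality\", \"privacy\"],\n    default=[],\n    help=\"the metric groups to evaluate\",\n)\nparser.parse_args([\"--metrics\"]).metrics             # ['quality', 'privacy'] (falls back to `const`)\nparser.parse_args([\"--metrics\", \"quality\"]).metrics  # ['quality']\nparser.parse_args([]).metrics                        # [] (flag absent, so `default` applies)\n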
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_dataloader_args","title":"add_dataloader_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing dataloader module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_dataloader_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing dataloader module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--data-dir\",\n        type=str,\n        default=\"./data\",\n        help=\"the directory containing the chosen dataset\",\n    )\n    group.add_argument(\n        \"--index-col\",\n        default=None,\n        nargs=\"*\",\n        help=\"indicate the name of the index column(s) in the csv file, such that pandas can index by it\",\n    )\n    group.add_argument(\n        \"--constraint-graph\",\n        type=str,\n        default=\"_constraint_graph\",\n        help=\"the name of the html file to write the constraint graph to, defaults to `<DATASET>_constraint_graph`\",\n    )\n    group.add_argument(\n        \"--collapse-yaml\",\n        action=\"store_true\",\n        help=\"use aliases and anchors in the output metadata yaml, this will make it much more compact\",\n    )\n    group.add_argument(\n        \"--missingness\",\n        type=str,\n        default=\"augment\",\n        choices=MISSINGNESS_STRATEGIES,\n        help=\"how to handle missing values in the dataset\",\n    )\n    group.add_argument(\n        \"--impute\",\n        type=str,\n        default=None,\n        help=\"the imputation strategy to use, ONLY USED if <MISSINGNESS> is set to 'impute', choose from: 'mean', 'median', 'mode', or any specific value (e.g. '0')\",\n    )\n    group.add_argument(\n        \"--write-csv\",\n        action=\"store_true\",\n        help=\"write the transformed real data to a csv file\",\n    )\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_evaluation_args","title":"add_evaluation_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing evaluation module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_evaluation_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing evaluation module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--downstream-tasks\",\n        \"--tasks\",\n        action=\"store_true\",\n        help=\"run the downstream tasks evaluation\",\n    )\n    group.add_argument(\n        \"--tasks-dir\",\n        type=str,\n        default=\"./tasks\",\n        help=\"the directory containing the downstream tasks to run, this directory must contain a folder called <DATASET> containing the tasks to run\",\n    )\n    group.add_argument(\n        \"--aequitas\",\n        action=\"store_true\",\n        help=\"run the aequitas fairness evaluation (note this runs for each of the downstream tasks)\",\n    )\n    group.add_argument(\n        \"--aequitas-attributes\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the attributes to use for the aequitas fairness evaluation, defaults to all attributes\",\n    )\n    group.add_argument(\n        \"--key-numerical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the numerical key field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--sensitive-numerical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the numerical sensitive field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--key-categorical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the categorical key field attributes to use for SDV privacy evaluations\",\n    )\n    group.add_argument(\n        \"--sensitive-categorical-fields\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"the categorical sensitive field attributes to use for SDV privacy evaluations\",\n    )\n    for name in METRIC_CHOICES:\n        generate_evaluation_arg(group, name)\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_model_args","title":"add_model_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing model module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing model module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--architecture\",\n        type=str,\n        nargs=\"+\",\n        default=[\"VAE\"],\n        choices=MODELS,\n        help=\"the model architecture(s) to train\",\n    )\n    group.add_argument(\n        \"--repeats\",\n        type=int,\n        default=1,\n        help=\"how many times to repeat the training process per model architecture (<SEED> is incremented each time)\",\n    )\n    group.add_argument(\n        \"--batch-size\",\n        type=int,\n        nargs=\"+\",\n        default=32,\n        help=\"the batch size for the model\",\n    )\n    group.add_argument(\n        \"--num-epochs\",\n        type=int,\n        nargs=\"+\",\n        default=100,\n        help=\"number of epochs to train for\",\n    )\n    group.add_argument(\n        \"--patience\",\n        type=int,\n        nargs=\"+\",\n        default=5,\n        help=\"how many epochs the model is allowed to train for without improvement\",\n    )\n    group.add_argument(\n        \"--displayed-metrics\",\n        type=str,\n        nargs=\"+\",\n        default=[],\n        help=\"metrics to display during training of the model, when set to `None`, all metrics are displayed\",\n    )\n    group.add_argument(\n        \"--use-gpu\",\n        action=\"store_true\",\n        help=\"use the GPU for training\",\n    )\n    group.add_argument(\n        \"--num-samples\",\n        type=int,\n        default=None,\n        help=\"the number of samples to generate from the model, defaults to the size of the original dataset\",\n    )\n    privacy_group = parser.add_argument_group(title=\"model privacy options\")\n    privacy_group.add_argument(\n        \"--target-epsilon\",\n        type=float,\n        nargs=\"+\",\n        default=1.0,\n        help=\"the target epsilon for differential privacy\",\n    )\n    privacy_group.add_argument(\n        \"--target-delta\",\n        type=float,\n        nargs=\"+\",\n        help=\"the target delta for differential privacy, defaults to `1 / len(dataset)` if not specified\",\n    )\n    privacy_group.add_argument(\n        \"--max-grad-norm\",\n        type=float,\n        nargs=\"+\",\n        default=5.0,\n        help=\"the clipping threshold for gradients (only relevant under differential privacy)\",\n    )\n    privacy_group.add_argument(\n        \"--secure-mode\",\n        action=\"store_true\",\n        help=\"Enable secure RNG via the `csprng` package to make privacy guarantees more robust, comes at a cost of performance and reproducibility\",\n    )\n    for model_name in MODELS.keys():\n        model_group = parser.add_argument_group(title=f\"{model_name}-specific options\")\n        add_model_specific_args(model_group, model_name, overrides=overrides)\n
"},{"location":"reference/cli/module_arguments/#nhssynth.cli.module_arguments.add_plotting_args","title":"add_plotting_args(parser, group_title, overrides=False)","text":"

Adds arguments to an existing plotting module sub-parser instance.

Source code in src/nhssynth/cli/module_arguments.py
def add_plotting_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False) -> None:\n    \"\"\"Adds arguments to an existing plotting module sub-parser instance.\"\"\"\n    group = parser.add_argument_group(title=group_title)\n    group.add_argument(\n        \"--plot-quality\",\n        action=\"store_true\",\n        help=\"plot the SDV quality report\",\n    )\n    group.add_argument(\n        \"--plot-diagnostic\",\n        action=\"store_true\",\n        help=\"plot the SDV diagnostic report\",\n    )\n    group.add_argument(\n        \"--plot-sdv-report\",\n        action=\"store_true\",\n        help=\"plot the SDV report\",\n    )\n    group.add_argument(\n        \"--plot-tsne\",\n        action=\"store_true\",\n        help=\"plot the t-SNE embeddings of the real and synthetic data\",\n    )\n
"},{"location":"reference/cli/module_setup/","title":"module_setup","text":"

Specify all CLI-accessible modules and their configurations, the pipeline to run by default, and define special functions for the config and pipeline CLI option trees.

"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.ModuleConfig","title":"ModuleConfig","text":"

Represents a module's configuration, containing the following attributes:

Attributes:

Name Type Description func

A callable that executes the module's functionality.

add_args

A callable that populates the module's sub-parser arguments.

description

A description of the module's functionality.

help

A help message for the module's command-line interface.

common_parsers

A list of common parsers to add to the module's sub-parser; the 'core' and 'seed' parsers are appended to those passed automatically.

Source code in src/nhssynth/cli/module_setup.py
class ModuleConfig:\n    \"\"\"\n    Represents a module's configuration, containing the following attributes:\n\n    Attributes:\n        func: A callable that executes the module's functionality.\n        add_args: A callable that populates the module's sub-parser arguments.\n        description: A description of the module's functionality.\n        help: A help message for the module's command-line interface.\n        common_parsers: A list of common parsers to add to the module's sub-parser, appending the 'dataset' and 'core' parsers to those passed.\n    \"\"\"\n\n    def __init__(\n        self,\n        func: Callable[..., argparse.Namespace],\n        add_args: Callable[..., None],\n        description: str,\n        help: str,\n        common_parsers: Optional[list[str]] = None,\n        no_seed: bool = False,\n    ) -> None:\n        self.func = func\n        self.add_args = add_args\n        self.description = description\n        self.help = help\n        self.common_parsers = [\"core\", \"seed\"] if not no_seed else [\"core\"]\n        if common_parsers:\n            assert set(common_parsers) <= COMMON_PARSERS.keys(), \"Invalid common parser(s) specified.\"\n            # merge the below two assert statements\n            assert (\n                \"core\" not in common_parsers and \"seed\" not in common_parsers\n            ), \"The 'seed' and 'core' parser groups are automatically added to all modules, remove the from `ModuleConfig`s.\"\n            self.common_parsers += common_parsers\n\n    def __call__(self, args: argparse.Namespace) -> argparse.Namespace:\n        return self.func(args)\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_config_args","title":"add_config_args(parser)","text":"

Adds arguments to parser relating to configuration file handling and module-specific config overrides.

Source code in src/nhssynth/cli/module_setup.py
def add_config_args(parser: argparse.ArgumentParser) -> None:\n    \"\"\"Adds arguments to `parser` relating to configuration file handling and module-specific config overrides.\"\"\"\n    parser.add_argument(\n        \"-c\",\n        \"--input-config\",\n        required=True,\n        help=\"specify the config file name\",\n    )\n    parser.add_argument(\n        \"-cp\",\n        \"--custom-pipeline\",\n        action=\"store_true\",\n        help=\"infer a custom pipeline running order of modules from the config\",\n    )\n    for module_name in PIPELINE:\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} option overrides\", overrides=True)\n    for module_name in VALID_MODULES - set(PIPELINE):\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} options overrides\", overrides=True)\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_pipeline_args","title":"add_pipeline_args(parser)","text":"

Adds arguments to parser for each module in the pipeline.

Source code in src/nhssynth/cli/module_setup.py
def add_pipeline_args(parser: argparse.ArgumentParser) -> None:\n    \"\"\"Adds arguments to `parser` for each module in the pipeline.\"\"\"\n    for module_name in PIPELINE:\n        MODULE_MAP[module_name].add_args(parser, f\"{module_name} options\")\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.add_subparser","title":"add_subparser(subparsers, name, module_config)","text":"

Add a subparser to an argparse argument parser.

Parameters:

Name Type Description Default subparsers _SubParsersAction

The subparsers action to which the subparser will be added.

required name str

The name of the subparser.

required module_config ModuleConfig

A ModuleConfig object containing information about the subparser, including a function to execute and a function to add arguments.

required

Returns:

Type Description ArgumentParser

The newly created subparser.

Source code in src/nhssynth/cli/module_setup.py
def add_subparser(\n    subparsers: argparse._SubParsersAction,\n    name: str,\n    module_config: ModuleConfig,\n) -> argparse.ArgumentParser:\n    \"\"\"\n    Add a subparser to an argparse argument parser.\n\n    Args:\n        subparsers: The subparsers action to which the subparser will be added.\n        name: The name of the subparser.\n        module_config: A [`ModuleConfig`][nhssynth.cli.module_setup.ModuleConfig] object containing information about the subparser, including a function to execute and a function to add arguments.\n\n    Returns:\n        The newly created subparser.\n    \"\"\"\n    parent_parsers = get_parent_parsers(name, module_config.common_parsers)\n    parser = subparsers.add_parser(\n        name=name,\n        description=module_config.description,\n        help=module_config.help,\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n        parents=parent_parsers,\n    )\n    if name not in {\"pipeline\", \"config\"}:\n        module_config.add_args(parser, f\"{name} options\")\n    else:\n        module_config.add_args(parser)\n    parser.set_defaults(func=module_config.func)\n    return parser\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.get_parent_parsers","title":"get_parent_parsers(name, module_parsers)","text":"

Get a list of parent parsers for a given module, based on the module's common_parsers attribute.

Source code in src/nhssynth/cli/module_setup.py
def get_parent_parsers(name: str, module_parsers: list[str]) -> list[argparse.ArgumentParser]:\n    \"\"\"Get a list of parent parsers for a given module, based on the module's `common_parsers` attribute.\"\"\"\n    if name in {\"pipeline\", \"config\"}:\n        return [p(name == \"config\") for p in COMMON_PARSERS.values()]\n    elif name == \"dashboard\":\n        return [COMMON_PARSERS[pn](True) for pn in module_parsers]\n    else:\n        return [COMMON_PARSERS[pn]() for pn in module_parsers]\n
"},{"location":"reference/cli/module_setup/#nhssynth.cli.module_setup.run_pipeline","title":"run_pipeline(args)","text":"

Runs the specified pipeline of modules with the passed configuration args.

Source code in src/nhssynth/cli/module_setup.py
def run_pipeline(args: argparse.Namespace) -> None:\n    \"\"\"Runs the specified pipeline of modules with the passed configuration `args`.\"\"\"\n    print(\"Running full pipeline...\")\n    args.modules_to_run = PIPELINE\n    for module_name in PIPELINE:\n        args = MODULE_MAP[module_name](args)\n
"},{"location":"reference/cli/run/","title":"run","text":""},{"location":"reference/common/","title":"common","text":""},{"location":"reference/common/common/","title":"common","text":"

Common functions for all modules.

"},{"location":"reference/common/common/#nhssynth.common.common.set_seed","title":"set_seed(seed=None)","text":"

(Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.

Parameters:

Name Type Description Default seed Optional[int]

The seed to set.

None Source code in src/nhssynth/common/common.py
def set_seed(seed: Optional[int] = None) -> None:\n    \"\"\"\n    (Potentially) set the seed for numpy, torch and random. If no seed is provided, nothing happens.\n\n    Args:\n        seed: The seed to set.\n    \"\"\"\n    if seed:\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        random.seed(seed)\n
"},{"location":"reference/common/constants/","title":"constants","text":"

Define all of the common constants used throughout the project.

"},{"location":"reference/common/debugging/","title":"debugging","text":"

Debugging utilities.

"},{"location":"reference/common/dicts/","title":"dicts","text":"

Common functions for working with dictionaries.

"},{"location":"reference/common/dicts/#nhssynth.common.dicts.filter_dict","title":"filter_dict(d, filter_keys, include=False)","text":"

Given a dictionary, return a new dictionary either including or excluding keys in a given filter set.

Parameters:

Name Type Description Default d dict

A dictionary to filter.

required filter_keys Union[set, list]

A list or set of keys to either include or exclude.

required include bool

Determine whether to return a dictionary including or excluding keys in filter.

False

Returns:

Type Description dict

A filtered dictionary.

Examples:

>>> d = {'a': 1, 'b': 2, 'c': 3}\n>>> filter_dict(d, {'a', 'b'})\n{'c': 3}\n>>> filter_dict(d, {'a', 'b'}, include=True)\n{'a': 1, 'b': 2}\n
Source code in src/nhssynth/common/dicts.py
def filter_dict(d: dict, filter_keys: Union[set, list], include: bool = False) -> dict:\n    \"\"\"\n    Given a dictionary, return a new dictionary either including or excluding keys in a given `filter` set.\n\n    Args:\n        d: A dictionary to filter.\n        filter_keys: A list or set of keys to either include or exclude.\n        include: Determine whether to return a dictionary including or excluding keys in `filter`.\n\n    Returns:\n        A filtered dictionary.\n\n    Examples:\n        >>> d = {'a': 1, 'b': 2, 'c': 3}\n        >>> filter_dict(d, {'a', 'b'})\n        {'c': 3}\n        >>> filter_dict(d, {'a', 'b'}, include=True)\n        {'a': 1, 'b': 2}\n    \"\"\"\n    if include:\n        filtered_keys = set(filter_keys) & set(d.keys())\n    else:\n        filtered_keys = set(d.keys()) - set(filter_keys)\n    return {k: v for k, v in d.items() if k in filtered_keys}\n
"},{"location":"reference/common/dicts/#nhssynth.common.dicts.flatten_dict","title":"flatten_dict(d)","text":"

Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.

Parameters:

Name Type Description Default d dict[str, Any]

A dictionary with potentially nested keys.

required

Returns:

Type Description dict[str, Any]

A flattened dictionary.

Raises:

Type Description ValueError

If duplicate keys are found in the flattened dictionary.

Examples:

>>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}\n>>> flatten_dict(d)\n{'a': 1, 'c': 2, 'e': 3}\n
Source code in src/nhssynth/common/dicts.py
def flatten_dict(d: dict[str, Any]) -> dict[str, Any]:\n    \"\"\"\n    Flatten a dictionary by recursively combining nested keys into a single dictionary until no nested keys remain.\n\n    Args:\n        d: A dictionary with potentially nested keys.\n\n    Returns:\n        A flattened dictionary.\n\n    Raises:\n        ValueError: If duplicate keys are found in the flattened dictionary.\n\n    Examples:\n        >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}\n        >>> flatten_dict(d)\n        {'a': 1, 'c': 2, 'e': 3}\n    \"\"\"\n    items = []\n    for k, v in d.items():\n        if isinstance(v, dict):\n            items.extend(flatten_dict(v).items())\n        else:\n            items.append((k, v))\n    if len(set([p[0] for p in items])) != len(items):\n        raise ValueError(\"Duplicate keys found in flattened dictionary\")\n    return dict(items)\n
"},{"location":"reference/common/dicts/#nhssynth.common.dicts.get_key_by_value","title":"get_key_by_value(d, value)","text":"

Find the first key in a dictionary with a given value.

Parameters:

Name Type Description Default d dict

A dictionary to search through.

required value Any

The value to search for.

required

Returns:

Type Description Union[Any, None]

The first key in d with the value value, or None if no such key exists.

Examples:

>>> d = {'a': 1, 'b': 2, 'c': 1}\n>>> get_key_by_value(d, 2)\n'b'\n>>> get_key_by_value(d, 3)\nNone\n
Source code in src/nhssynth/common/dicts.py
def get_key_by_value(d: dict, value: Any) -> Union[Any, None]:\n    \"\"\"\n    Find the first key in a dictionary with a given value.\n\n    Args:\n        d: A dictionary to search through.\n        value: The value to search for.\n\n    Returns:\n        The first key in `d` with the value `value`, or `None` if no such key exists.\n\n    Examples:\n        >>> d = {'a': 1, 'b': 2, 'c': 1}\n        >>> get_key_by_value(d, 2)\n        'b'\n        >>> get_key_by_value(d, 3)\n        None\n\n    \"\"\"\n    for key, val in d.items():\n        if val == value:\n            return key\n    return None\n
"},{"location":"reference/common/io/","title":"io","text":"

Common building-block functions for handling module input and output.

"},{"location":"reference/common/io/#nhssynth.common.io.check_exists","title":"check_exists(fns, dir)","text":"

Checks if the files in fns exist in dir.

Parameters:

Name Type Description Default fns list[str]

The list of files to check.

required dir Path

The directory the files should exist in.

required

Raises:

Type Description FileNotFoundError

If any of the files in fns do not exist in dir.

Source code in src/nhssynth/common/io.py
def check_exists(fns: list[str], dir: Path) -> None:\n    \"\"\"\n    Checks if the files in `fns` exist in `dir`.\n\n    Args:\n        fns: The list of files to check.\n        dir: The directory the files should exist in.\n\n    Raises:\n        FileNotFoundError: If any of the files in `fns` do not exist in `dir`.\n    \"\"\"\n    for fn in fns:\n        if not (dir / fn).exists():\n            raise FileNotFoundError(f\"File {fn} does not exist at {dir}.\")\n
"},{"location":"reference/common/io/#nhssynth.common.io.consistent_ending","title":"consistent_ending(fn, ending='.pkl', suffix='')","text":"

Ensures that the filename fn ends with ending. If not, removes any existing ending and appends ending.

Parameters:

Name Type Description Default fn str

The filename to check.

required ending str

The desired ending to check for. Default is \".pkl\".

'.pkl' suffix str

A suffix to append to the filename before the ending.

''

Returns:

Type Description str

The filename with the correct ending and potentially an inserted suffix.

Source code in src/nhssynth/common/io.py
def consistent_ending(fn: str, ending: str = \".pkl\", suffix: str = \"\") -> str:\n    \"\"\"\n    Ensures that the filename `fn` ends with `ending`. If not, removes any existing ending and appends `ending`.\n\n    Args:\n        fn: The filename to check.\n        ending: The desired ending to check for. Default is \".pkl\".\n        suffix: A suffix to append to the filename before the ending.\n\n    Returns:\n        The filename with the correct ending and potentially an inserted suffix.\n    \"\"\"\n    path_fn = Path(fn)\n    return str(path_fn.parent / path_fn.stem) + (\"_\" if suffix else \"\") + suffix + ending\n
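
For example, based on the source above:

>>> consistent_ending(\"model.pt\")\n'model.pkl'\n>>> consistent_ending(\"report\", ending=\".html\", suffix=\"v2\")\n'report_v2.html'\n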
"},{"location":"reference/common/io/#nhssynth.common.io.consistent_endings","title":"consistent_endings(args)","text":"

Wrapper around consistent_ending to apply it to a list of filenames.

Parameters:

Name Type Description Default args list[Union[str, tuple[str, str], tuple[str, str, str]]]

The list of filenames to check. Can take the form of a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix.

required

Returns:

Type Description list[str]

The list of filenames with the correct endings.

Source code in src/nhssynth/common/io.py
def consistent_endings(args: list[Union[str, tuple[str, str], tuple[str, str, str]]]) -> list[str]:\n    \"\"\"\n    Wrapper around `consistent_ending` to apply it to a list of filenames.\n\n    Args:\n        args: The list of filenames to check. Can take the form of a single filename, a pair of a filename and an ending, or a triple of a filename, an ending and a suffix.\n\n    Returns:\n        The list of filenames with the correct endings.\n    \"\"\"\n    return list(consistent_ending(arg) if isinstance(arg, str) else consistent_ending(*arg) for arg in args)\n
"},{"location":"reference/common/io/#nhssynth.common.io.experiment_io","title":"experiment_io(experiment_name, dir_experiments='experiments')","text":"

Create an experiment's directory and return the path.

Parameters:

Name Type Description Default experiment_name str

The name of the experiment.

required dir_experiments str

The name of the directory containing all experiments.

'experiments'

Returns:

Type Description str

The path to the experiment directory.

Source code in src/nhssynth/common/io.py
def experiment_io(experiment_name: str, dir_experiments: str = \"experiments\") -> str:\n    \"\"\"\n    Create an experiment's directory and return the path.\n\n    Args:\n        experiment_name: The name of the experiment.\n        dir_experiments: The name of the directory containing all experiments.\n\n    Returns:\n        The path to the experiment directory.\n    \"\"\"\n    dir_experiment = Path(dir_experiments) / experiment_name\n    dir_experiment.mkdir(parents=True, exist_ok=True)\n    return dir_experiment\n
"},{"location":"reference/common/io/#nhssynth.common.io.potential_suffix","title":"potential_suffix(fn, fn_base)","text":"

Checks if fn is a suffix (starts with an underscore) to append to fn_base, or a filename in its own right.

Parameters:

Name Type Description Default fn str

The filename / potential suffix to append to fn_base.

required fn_base str

The name of the file the suffix would attach to.

required

Returns:

Type Description str

The appropriately processed fn.

Source code in src/nhssynth/common/io.py
def potential_suffix(fn: str, fn_base: str) -> str:\n    \"\"\"\n    Checks if `fn` is a suffix (starts with an underscore) to append to `fn_base`, or a filename in its own right.\n\n    Args:\n        fn: The filename / potential suffix to append to `fn_base`.\n        fn_base: The name of the file the suffix would attach to.\n\n    Returns:\n        The appropriately processed `fn`\n    \"\"\"\n    fn_base = Path(fn_base).stem\n    if fn[0] == \"_\":\n        return fn_base + fn\n    else:\n        return fn\n
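
For example:

>>> potential_suffix(\"_typed\", \"dataset.csv\")\n'dataset_typed'\n>>> potential_suffix(\"other_file\", \"dataset.csv\")\n'other_file'\n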
"},{"location":"reference/common/io/#nhssynth.common.io.potential_suffixes","title":"potential_suffixes(fns, fn_base)","text":"

Wrapper around potential_suffix to apply it to a list of filenames.

Parameters:

Name Type Description Default fns list[str]

The list of filenames / potential suffixes to append to fn_base.

required fn_base str

The name of the file the suffixes would attach to.

required Source code in src/nhssynth/common/io.py
def potential_suffixes(fns: list[str], fn_base: str) -> list[str]:\n    \"\"\"\n    Wrapper around `potential_suffix` to apply it to a list of filenames.\n\n    Args:\n        fns: The list of filenames / potential suffixes to append to `fn_base`.\n        fn_base: The name of the file the suffixes would attach to.\n    \"\"\"\n    return list(potential_suffix(fn, fn_base) for fn in fns)\n
"},{"location":"reference/common/io/#nhssynth.common.io.warn_if_path_supplied","title":"warn_if_path_supplied(fns, dir)","text":"

Warns if the files in fns include directory separators.

Parameters:

Name Type Description Default fns list[str]

The list of files to check.

required dir Path

The directory the files should exist in.

required

Warns:

Type Description UserWarning

when the path to any of the files in fns includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.

Source code in src/nhssynth/common/io.py
def warn_if_path_supplied(fns: list[str], dir: Path) -> None:\n    \"\"\"\n    Warns if the files in `fns` include directory separators.\n\n    Args:\n        fns: The list of files to check.\n        dir: The directory the files should exist in.\n\n    Warnings:\n        UserWarning: when the path to any of the files in `fns` includes directory separators, as this may lead to unintended consequences if the user doesn't realise default directories are pre-specified.\n    \"\"\"\n    for fn in fns:\n        if \"/\" in fn:\n            warnings.warn(\n                f\"Using the path supplied appended to {dir}, i.e. attempting to read data from {dir / fn}\",\n                UserWarning,\n            )\n
"},{"location":"reference/common/strings/","title":"strings","text":"

String manipulation functions.

"},{"location":"reference/common/strings/#nhssynth.common.strings.add_spaces_before_caps","title":"add_spaces_before_caps(string)","text":"

Adds spaces before capital letters in a string if there is a lower-case letter following it.

Parameters:

Name Type Description Default string str

The string to add spaces to.

required

Returns:

Type Description str

The string with spaces added before capital letters.

Examples:

>>> add_spaces_before_caps(\"HelloWorld\")\n'Hello World'\n>>> add_spaces_before_caps(\"HelloWorldAGAIN\")\n'Hello World AGAIN'\n
Source code in src/nhssynth/common/strings.py
def add_spaces_before_caps(string: str) -> str:\n    \"\"\"\n    Adds spaces before capital letters in a string if there is a lower-case letter following it.\n\n    Args:\n        string: The string to add spaces to.\n\n    Returns:\n        The string with spaces added before capital letters.\n\n    Examples:\n        >>> add_spaces_before_caps(\"HelloWorld\")\n        'Hello World'\n        >>> add_spaces_before_caps(\"HelloWorldAGAIN\")\n        'Hello World AGAIN'\n    \"\"\"\n    return \" \".join(re.findall(r\"[a-z]?[A-Z][a-z]+|[A-Z]+(?=[A-Z][a-z]|\\b)\", string))\n
"},{"location":"reference/common/strings/#nhssynth.common.strings.format_timedelta","title":"format_timedelta(start, finish)","text":"

Calculate and prettily format the difference between two calls to time.time().

Parameters:

- start (float, required): The start time.
- finish (float, required): The finish time.

Returns:

- str: A string containing the time difference in a human-readable format.

Source code in src/nhssynth/common/strings.py
def format_timedelta(start: float, finish: float) -> str:\n    \"\"\"\n    Calculate and prettily format the difference between two calls to `time.time()`.\n\n    Args:\n        start: The start time.\n        finish: The finish time.\n\n    Returns:\n        A string containing the time difference in a human-readable format.\n    \"\"\"\n    total = datetime.timedelta(seconds=finish - start)\n    hours, remainder = divmod(total.seconds, 3600)\n    minutes, seconds = divmod(remainder, 60)\n\n    if total.days > 0:\n        delta_str = f\"{total.days}d {hours}h {minutes}m {seconds}s\"\n    elif hours > 0:\n        delta_str = f\"{hours}h {minutes}m {seconds}s\"\n    elif minutes > 0:\n        delta_str = f\"{minutes}m {seconds}s\"\n    else:\n        delta_str = f\"{seconds}s\"\n    return delta_str\n
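A quick illustration of the formatting behaviour (the elapsed time is simulated rather than measured):

import time
from nhssynth.common.strings import format_timedelta

start = time.time()
finish = start + 3725  # pretend 1 hour, 2 minutes and 5 seconds have elapsed
print(format_timedelta(start, finish))  # "1h 2m 5s"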
"},{"location":"reference/modules/","title":"modules","text":""},{"location":"reference/modules/dashboard/","title":"dashboard","text":""},{"location":"reference/modules/dashboard/Upload/","title":"Upload","text":""},{"location":"reference/modules/dashboard/Upload/#nhssynth.modules.dashboard.Upload.get_component","title":"get_component(args, name, component_type, text)","text":"

Generate an upload field and its functionality for a given component of the evaluations.

Parameters:

- name (str, required): The name of the component as it should be recorded in the session state and as it exists in the args.
- component_type (Any, required): The type of the component (to ensure that only the expected object can be uploaded).
- text (str, required): The human-readable text to display to the user as part of the element.

Source code in src/nhssynth/modules/dashboard/Upload.py
def get_component(args: argparse.Namespace, name: str, component_type: Any, text: str) -> None:\n    \"\"\"\n    Generate an upload field and its functionality for a given component of the evaluations.\n\n    Args:\n        name: The name of the component as it should be recorded in the session state and as it exists in the args.\n        component_type: The type of the component (to ensure that only the expected object can be uploaded)\n        text: The human-readable text to display to the user as part of the element.\n    \"\"\"\n    uploaded = st.file_uploader(f\"Upload a pickle file containing a {text}\", type=\"pkl\")\n    if getattr(args, name):\n        with open(os.getcwd() + \"/\" + getattr(args, name), \"rb\") as f:\n            loaded = pickle.load(f)\n    if uploaded is not None:\n        loaded = pickle.load(uploaded)\n    if loaded is not None:\n        assert isinstance(loaded, component_type), f\"Uploaded file does not contain a {text}!\"\n        st.session_state[name] = loaded.contents\n        st.success(f\"Loaded {text}!\")\n
"},{"location":"reference/modules/dashboard/Upload/#nhssynth.modules.dashboard.Upload.parse_args","title":"parse_args()","text":"

These arguments allow a user to automatically load the required data for the dashboard from disk.

Returns:

- Namespace: The parsed arguments.

Source code in src/nhssynth/modules/dashboard/Upload.py
def parse_args() -> argparse.Namespace:\n    \"\"\"\n    These arguments allow a user to automatically load the required data for the dashboard from disk.\n\n    Returns:\n        The parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"NHSSynth Evaluation Dashboard\")\n    parser.add_argument(\"--evaluations\", type=str, help=\"Path to a set of evaluations.\")\n    parser.add_argument(\"--experiments\", type=str, help=\"Path to a set of experiments.\")\n    parser.add_argument(\"--synthetic-datasets\", type=str, help=\"Path to a set of synthetic datasets.\")\n    parser.add_argument(\"--typed\", type=str, help=\"Path to a typed real dataset.\")\n    return parser.parse_args()\n
"},{"location":"reference/modules/dashboard/io/","title":"io","text":""},{"location":"reference/modules/dashboard/io/#nhssynth.modules.dashboard.io.check_input_paths","title":"check_input_paths(dir_experiment, fn_dataset, fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations)","text":"

Sets up the input paths for the dashboard's data files.

Parameters:

- dir_experiment (str, required): The path to the experiment directory.
- fn_dataset (str, required): The base name of the dataset.
- fn_experiments (str, required): The filename of the collection of experiments.
- fn_synthetic_datasets (str, required): The filename of the collection of synthetic datasets.
- fn_evaluations (str, required): The filename of the collection of evaluations.

Returns:

- The paths to the typed dataset, experiments, synthetic datasets, and evaluations files within the experiment directory.

Source code in src/nhssynth/modules/dashboard/io.py
def check_input_paths(\n    dir_experiment: str,\n    fn_dataset: str,\n    fn_typed: str,\n    fn_experiments: str,\n    fn_synthetic_datasets: str,\n    fn_evaluations: str,\n) -> str:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        dir_experiment: The path to the experiment directory.\n        fn_dataset: The base name of the dataset.\n        fn_experiments: The filename of the collection of experiments.\n        fn_synthetic_datasets: The filename of the collection of synthetic datasets.\n        fn_evaluations: The filename of the collection of evaluations.\n\n    Returns:\n        The paths\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations = io.consistent_endings(\n        [fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations]\n    )\n    fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations = io.potential_suffixes(\n        [fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], fn_dataset\n    )\n    io.warn_if_path_supplied([fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], dir_experiment)\n    io.check_exists([fn_typed, fn_experiments, fn_synthetic_datasets, fn_evaluations], dir_experiment)\n    return (\n        dir_experiment / fn_typed,\n        dir_experiment / fn_experiments,\n        dir_experiment / fn_synthetic_datasets,\n        dir_experiment / fn_evaluations,\n    )\n
"},{"location":"reference/modules/dashboard/run/","title":"run","text":""},{"location":"reference/modules/dashboard/utils/","title":"utils","text":""},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.hide_streamlit_content","title":"hide_streamlit_content()","text":"

Hide the footer message and deploy button in Streamlit.

Source code in src/nhssynth/modules/dashboard/utils.py
def hide_streamlit_content() -> None:\n    \"\"\"\n    Hide the footer message and deploy button in Streamlit.\n    \"\"\"\n    hide_streamlit_style = \"\"\"\n    <style>\n    footer {visibility: hidden;}\n    .stDeployButton {visibility: hidden;}\n    </style>\n    \"\"\"\n    st.markdown(hide_streamlit_style, unsafe_allow_html=True)\n
"},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.id_selector","title":"id_selector(df)","text":"

Select an ID from the dataframe to then operate on.

Parameters:

- df (DataFrame, required): The dataframe to select an ID from.

Returns:

- Series: The subset of the dataframe containing only the row corresponding to the selected ID.

Source code in src/nhssynth/modules/dashboard/utils.py
def id_selector(df: pd.DataFrame) -> pd.Series:\n    \"\"\"\n    Select an ID from the dataframe to then operate on.\n\n    Args:\n        df: The dataframe to select an ID from.\n\n    Returns:\n        The dataset subset to only the row corresponding to the ID.\n    \"\"\"\n    architecture = st.sidebar.selectbox(\n        \"Select architecture to display\", df.index.get_level_values(\"architecture\").unique()\n    )\n    # Different architectures may have different numbers of repeats and configs\n    repeats = df.loc[architecture].index.get_level_values(\"repeat\").astype(int).unique()\n    configs = df.loc[architecture].index.get_level_values(\"config\").astype(int).unique()\n    if len(repeats) > 1:\n        repeat = st.sidebar.selectbox(\"Select repeat to display\", repeats)\n    else:\n        repeat = repeats[0]\n    if len(configs) > 1:\n        config = st.sidebar.selectbox(\"Select configuration to display\", configs)\n    else:\n        config = configs[0]\n    return df.loc[(architecture, repeat, config)]\n
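This (and subset_selector below) assumes the dataframe carries a three-level MultiIndex of architecture, repeat and config. A minimal sketch of that shape with made-up values; inside a running Streamlit app the sidebar widgets supply the labels that are passed to df.loc:

import pandas as pd

index = pd.MultiIndex.from_tuples(
    [("VAE", 0, 0), ("VAE", 1, 0), ("DPVAE", 0, 0)],
    names=["architecture", "repeat", "config"],
)
df = pd.DataFrame({"metric": [0.91, 0.89, 0.84]}, index=index)

# Equivalent to the selection id_selector performs once the sidebar choices are made
print(df.loc[("VAE", 1, 0)])  # the single row for that architecture/repeat/config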
"},{"location":"reference/modules/dashboard/utils/#nhssynth.modules.dashboard.utils.subset_selector","title":"subset_selector(df)","text":"

Select a subset of the dataframe to then operate on.

Parameters:

- df (DataFrame, required): The dataframe to select a subset of.

Returns:

- DataFrame: The subset of the dataframe.

Source code in src/nhssynth/modules/dashboard/utils.py
def subset_selector(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Select a subset of the dataframe to then operate on.\n\n    Args:\n        df: The dataframe to select a subset of.\n\n    Returns:\n        The subset of the dataframe.\n    \"\"\"\n    architectures = df.index.get_level_values(\"architecture\").unique().tolist()\n    repeats = df.index.get_level_values(\"repeat\").astype(int).unique().tolist()\n    configs = df.index.get_level_values(\"config\").astype(int).unique().tolist()\n    selected_architectures = st.sidebar.multiselect(\n        \"Select architectures to display\", architectures, default=architectures\n    )\n    selected_repeats = st.sidebar.multiselect(\"Select repeats to display\", repeats, default=repeats[0])\n    selected_configs = st.sidebar.multiselect(\"Select configurations to display\", configs, default=configs)\n    return df.loc[(selected_architectures, selected_repeats, selected_configs)]\n
"},{"location":"reference/modules/dashboard/pages/","title":"pages","text":""},{"location":"reference/modules/dashboard/pages/1_Tables/","title":"1_Tables","text":""},{"location":"reference/modules/dashboard/pages/2_Plots/","title":"2_Plots","text":""},{"location":"reference/modules/dashboard/pages/2_Plots/#nhssynth.modules.dashboard.pages.2_Plots.prepare_for_dimensionality","title":"prepare_for_dimensionality(df)","text":"

Factorize all categorical columns in a dataframe.

Source code in src/nhssynth/modules/dashboard/pages/2_Plots.py
def prepare_for_dimensionality(df: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"Factorize all categorical columns in a dataframe.\"\"\"\n    for col in df.columns:\n        if df[col].dtype == \"object\":\n            df[col] = pd.factorize(df[col])[0]\n        elif df[col].dtype == \"datetime64[ns]\":\n            df[col] = pd.to_numeric(df[col])\n        min_val = df[col].min()\n        max_val = df[col].max()\n        df[col] = (df[col] - min_val) / (max_val - min_val)\n    return df\n
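The helper's body is short enough to mirror directly; this sketch on made-up data shows the same three steps it performs: factorising object columns, converting datetimes to numbers, and min-max scaling every column to [0, 1]:

import pandas as pd

df = pd.DataFrame({
    "sex": ["M", "F", "F", "M"],
    "age": [20, 40, 60, 80],
    "admitted": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01", "2021-06-01"]),
})

for col in df.columns:
    if df[col].dtype == "object":
        df[col] = pd.factorize(df[col])[0]  # "M"/"F" -> 0/1
    elif df[col].dtype == "datetime64[ns]":
        df[col] = pd.to_numeric(df[col])  # datetimes -> nanosecond integers
    df[col] = (df[col] - df[col].min()) / (df[col].max() - df[col].min())  # scale to [0, 1]

print(df.round(2))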
"},{"location":"reference/modules/dashboard/pages/3_Experiment_Configurations/","title":"3_Experiment_Configurations","text":""},{"location":"reference/modules/dataloader/","title":"dataloader","text":""},{"location":"reference/modules/dataloader/constraints/","title":"constraints","text":""},{"location":"reference/modules/dataloader/io/","title":"io","text":""},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.check_input_paths","title":"check_input_paths(fn_input, fn_metadata, dir_data)","text":"

Formats the input filenames and directory for an experiment.

Parameters:

- fn_input (str, required): The input data filename.
- fn_metadata (str, required): The metadata filename / suffix to append to fn_input.
- dir_data (str, required): The directory that should contain both of the above.

Returns:

- tuple[Path, str, str]: A tuple containing the correct directory path, input data filename and metadata filename (used for both in and out).

Warns:

- UserWarning: When the path to fn_input includes directory separators, as this is not supported and may not work as intended.
- UserWarning: When the path to fn_metadata includes directory separators, as this is not supported and may not work as intended.

Source code in src/nhssynth/modules/dataloader/io.py
def check_input_paths(\n    fn_input: str,\n    fn_metadata: str,\n    dir_data: str,\n) -> tuple[Path, str, str]:\n    \"\"\"\n    Formats the input filenames and directory for an experiment.\n\n    Args:\n        fn_input: The input data filename.\n        fn_metadata: The metadata filename / suffix to append to `fn_input`.\n        dir_data: The directory that should contain both of the above.\n\n    Returns:\n        A tuple containing the correct directory path, input data filename and metadata filename (used for both in and out).\n\n    Warnings:\n        UserWarning: When the path to `fn_input` includes directory separators, as this is not supported and may not work as intended.\n        UserWarning: When the path to `fn_metadata` includes directory separators, as this is not supported and may not work as intended.\n    \"\"\"\n    fn_input, fn_metadata = io.consistent_endings([(fn_input, \".csv\"), (fn_metadata, \".yaml\")])\n    dir_data = Path(dir_data)\n    fn_metadata = io.potential_suffix(fn_metadata, fn_input)\n    io.warn_if_path_supplied([fn_input, fn_metadata], dir_data)\n    io.check_exists([fn_input], dir_data)\n    return dir_data, fn_input, fn_metadata\n
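A sketch of a typical call, assuming data/my_dataset.csv already exists (the final check_exists step raises if it does not) and that the supplied extensions (.csv / .yaml) are already consistent; note how the "_metadata" suffix is expanded using the convention from the common io helpers:

from nhssynth.modules.dataloader.io import check_input_paths

dir_data, fn_input, fn_metadata = check_input_paths(
    fn_input="my_dataset.csv",
    fn_metadata="_metadata.yaml",
    dir_data="data",
)
print(dir_data, fn_input, fn_metadata)  # data my_dataset.csv my_dataset_metadata.yaml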
"},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.check_output_paths","title":"check_output_paths(fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata, dir_experiment)","text":"

Formats the output filenames for an experiment.

Parameters:

- fn_dataset (str, required): The input data filename.
- fn_typed (str, required): The typed input data filename/suffix to append to fn_dataset.
- fn_transformed (str, required): The transformed output data filename/suffix to append to fn_dataset.
- fn_metatransformer (str, required): The metatransformer filename/suffix to append to fn_dataset.
- fn_constraint_graph (str, required): The constraint graph filename/suffix to append to fn_dataset.
- fn_sdv_metadata (str, required): The SDV metadata filename/suffix to append to fn_dataset.
- dir_experiment (Path, required): The experiment directory to write the outputs to.

Returns:

- tuple[str, str, str]: A tuple containing the formatted output filenames.

Warns:

- UserWarning: When any of the filenames include directory separators, as this is not supported and may not work as intended.

Source code in src/nhssynth/modules/dataloader/io.py
def check_output_paths(\n    fn_dataset: str,\n    fn_typed: str,\n    fn_transformed: str,\n    fn_metatransformer: str,\n    fn_constraint_graph: str,\n    fn_sdv_metadata: str,\n    dir_experiment: Path,\n) -> tuple[str, str, str]:\n    \"\"\"\n    Formats the output filenames for an experiment.\n\n    Args:\n        fn_dataset: The input data filename.\n        fn_typed: The typed input data filename/suffix to append to `fn_dataset`.\n        fn_transformed: The transformed output data filename/suffix to append to `fn_dataset`.\n        fn_metatransformer: The metatransformer filename/suffix to append to `fn_dataset`.\n        fn_constraint_graph: The constraint graph filename/suffix to append to `fn_dataset`.\n        fn_sdv_metadata: The SDV metadata filename/suffix to append to `fn_dataset`.\n        dir_experiment: The experiment directory to write the outputs to.\n\n    Returns:\n        A tuple containing the formatted output filenames.\n\n    Warnings:\n        UserWarning: When any of the filenames include directory separators, as this is not supported and may not work as intended.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = io.consistent_endings(\n        [fn_typed, fn_transformed, fn_metatransformer, (fn_constraint_graph, \".html\"), fn_sdv_metadata]\n    )\n    fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = io.potential_suffixes(\n        [fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata], fn_dataset\n    )\n    io.warn_if_path_supplied(\n        [fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata], dir_experiment\n    )\n    return fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata\n
"},{"location":"reference/modules/dataloader/io/#nhssynth.modules.dataloader.io.write_data_outputs","title":"write_data_outputs(metatransformer, fn_dataset, fn_metadata, dir_experiment, args)","text":"

Writes the transformed data and metatransformer to disk.

Parameters:

- metatransformer (MetaTransformer, required): The metatransformer used to transform the data into its model-ready state.
- fn_dataset (str, required): The base dataset filename.
- fn_metadata (str, required): The metadata filename.
- dir_experiment (Path, required): The experiment directory to write the outputs to.
- args (Namespace, required): The full set of parsed command line arguments.

Returns:

- str: The filename of the dataset used.

Source code in src/nhssynth/modules/dataloader/io.py
def write_data_outputs(\n    metatransformer: MetaTransformer,\n    fn_dataset: str,\n    fn_metadata: str,\n    dir_experiment: Path,\n    args: argparse.Namespace,\n) -> None:\n    \"\"\"\n    Writes the transformed data and metatransformer to disk.\n\n    Args:\n        metatransformer: The metatransformer used to transform the data into its model-ready state.\n        fn_dataset: The base dataset filename.\n        fn_metadata: The metadata filename.\n        dir_experiment: The experiment directory to write the outputs to.\n        args: The full set of parsed command line arguments.\n\n    Returns:\n        The filename of the dataset used.\n    \"\"\"\n    fn_dataset, fn_typed, fn_transformed, fn_metatransformer, fn_constraint_graph, fn_sdv_metadata = check_output_paths(\n        fn_dataset,\n        args.typed,\n        args.transformed,\n        args.metatransformer,\n        args.constraint_graph,\n        args.sdv_metadata,\n        dir_experiment,\n    )\n    metatransformer.save_metadata(dir_experiment / fn_metadata, args.collapse_yaml)\n    metatransformer.save_constraint_graphs(dir_experiment / fn_constraint_graph)\n    with open(dir_experiment / fn_typed, \"wb\") as f:\n        pickle.dump(TypedDataset(metatransformer.get_typed_dataset()), f)\n    transformed_dataset = metatransformer.get_transformed_dataset()\n    transformed_dataset.to_pickle(dir_experiment / fn_transformed)\n    if args.write_csv:\n        chunks = np.array_split(transformed_dataset.index, 100)\n        for chunk, subset in enumerate(tqdm(chunks, desc=\"Writing transformed dataset to CSV\", unit=\"chunk\")):\n            if chunk == 0:\n                transformed_dataset.loc[subset].to_csv(\n                    dir_experiment / (fn_transformed[:-3] + \"csv\"), mode=\"w\", index=False\n                )\n            else:\n                transformed_dataset.loc[subset].to_csv(\n                    dir_experiment / (fn_transformed[:-3] + \"csv\"), mode=\"a\", index=False, header=False\n                )\n    with open(dir_experiment / fn_metatransformer, \"wb\") as f:\n        pickle.dump(metatransformer, f)\n    with open(dir_experiment / fn_sdv_metadata, \"wb\") as f:\n        pickle.dump(metatransformer.get_sdv_metadata(), f)\n\n    return fn_dataset\n
"},{"location":"reference/modules/dataloader/metadata/","title":"metadata","text":""},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData","title":"MetaData","text":"Source code in src/nhssynth/modules/dataloader/metadata.py
class MetaData:\n    class ColumnMetaData:\n        def __init__(self, name: str, data: pd.Series, raw: dict) -> None:\n            self.name = name\n            self.dtype: np.dtype = self._validate_dtype(data, raw.get(\"dtype\"))\n            self.categorical: bool = self._validate_categorical(data, raw.get(\"categorical\"))\n            self.missingness_strategy: GenericMissingnessStrategy = self._validate_missingness_strategy(\n                raw.get(\"missingness\")\n            )\n            self.transformer: ColumnTransformer = self._validate_transformer(raw.get(\"transformer\"))\n\n        def _validate_dtype(self, data: pd.Series, dtype_raw: Optional[Union[dict, str]] = None) -> np.dtype:\n            if isinstance(dtype_raw, dict):\n                dtype_name = dtype_raw.pop(\"name\", None)\n            elif isinstance(dtype_raw, str):\n                dtype_name = dtype_raw\n            else:\n                dtype_name = self._infer_dtype(data)\n            try:\n                dtype = np.dtype(dtype_name)\n            except TypeError:\n                warnings.warn(\n                    f\"Invalid dtype specification '{dtype_name}' for column '{self.name}', ignoring dtype for this column\"\n                )\n                dtype = self._infer_dtype(data)\n            if dtype.kind == \"M\":\n                self._setup_datetime_config(data, dtype_raw)\n            elif dtype.kind in [\"f\", \"i\", \"u\"]:\n                self.rounding_scheme = self._validate_rounding_scheme(data, dtype, dtype_raw)\n            return dtype\n\n        def _infer_dtype(self, data: pd.Series) -> np.dtype:\n            return data.dtype.name\n\n        def _infer_datetime_format(self, data: pd.Series) -> str:\n            return _guess_datetime_format_for_array(data[data.notna()].astype(str).to_numpy())\n\n        def _setup_datetime_config(self, data: pd.Series, datetime_config: dict) -> dict:\n            \"\"\"\n            Add keys to `datetime_config` corresponding to args from the `pd.to_datetime` function\n            (see [the docs](https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html))\n            \"\"\"\n            if not isinstance(datetime_config, dict):\n                datetime_config = {}\n            else:\n                datetime_config = filter_dict(datetime_config, {\"format\", \"floor\"}, include=True)\n            if \"format\" not in datetime_config:\n                datetime_config[\"format\"] = self._infer_datetime_format(data)\n            self.datetime_config = datetime_config\n\n        def _validate_rounding_scheme(self, data: pd.Series, dtype: np.dtype, dtype_dict: dict) -> int:\n            if dtype_dict and \"rounding_scheme\" in dtype_dict:\n                return dtype_dict[\"rounding_scheme\"]\n            else:\n                if dtype.kind != \"f\":\n                    return 1.0\n                roundable_data = data[data.notna()]\n                for i in range(np.finfo(dtype).precision):\n                    if (roundable_data.round(i) == roundable_data).all():\n                        return 10**-i\n            return None\n\n        def _validate_categorical(self, data: pd.Series, categorical: Optional[bool] = None) -> bool:\n            if categorical is None:\n                return self._infer_categorical(data)\n            elif not isinstance(categorical, bool):\n                warnings.warn(\n                    f\"Invalid categorical '{categorical}' for column '{self.name}', ignoring categorical for this column\"\n         
       )\n                return self._infer_categorical(data)\n            else:\n                self.boolean = data.nunique() <= 2\n                return categorical\n\n        def _infer_categorical(self, data: pd.Series) -> bool:\n            self.boolean = data.nunique() <= 2\n            return data.nunique() <= 10 or self.dtype.kind == \"O\"\n\n        def _validate_missingness_strategy(self, missingness_strategy: Optional[Union[dict, str]]) -> tuple[str, dict]:\n            if not missingness_strategy:\n                return None\n            if isinstance(missingness_strategy, dict):\n                impute = missingness_strategy.get(\"impute\", None)\n                strategy = \"impute\" if impute else missingness_strategy.get(\"strategy\", None)\n            else:\n                strategy = missingness_strategy\n            if (\n                strategy not in MISSINGNESS_STRATEGIES\n                or (strategy == \"impute\" and impute == \"mean\" and self.dtype.kind != \"f\")\n                or (strategy == \"impute\" and not impute)\n            ):\n                warnings.warn(\n                    f\"Invalid missingness strategy '{missingness_strategy}' for column '{self.name}', ignoring missingness strategy for this column\"\n                )\n                return None\n            return (\n                MISSINGNESS_STRATEGIES[strategy](impute) if strategy == \"impute\" else MISSINGNESS_STRATEGIES[strategy]()\n            )\n\n        def _validate_transformer(self, transformer: Optional[Union[dict, str]] = {}) -> tuple[str, dict]:\n            # if transformer is neither a dict nor a str statement below will raise a TypeError\n            if isinstance(transformer, dict):\n                self.transformer_name = transformer.get(\"name\")\n                self.transformer_config = filter_dict(transformer, \"name\")\n            elif isinstance(transformer, str):\n                self.transformer_name = transformer\n                self.transformer_config = {}\n            else:\n                if transformer is not None:\n                    warnings.warn(\n                        f\"Invalid transformer config '{transformer}' for column '{self.name}', ignoring transformer for this column\"\n                    )\n                self.transformer_name = None\n                self.transformer_config = {}\n            if not self.transformer_name:\n                return self._infer_transformer()\n            else:\n                try:\n                    return eval(self.transformer_name)(**self.transformer_config)\n                except NameError:\n                    warnings.warn(\n                        f\"Invalid transformer '{self.transformer_name}' or config '{self.transformer_config}' for column '{self.name}', ignoring transformer for this column\"\n                    )\n                    return self._infer_transformer()\n\n        def _infer_transformer(self) -> ColumnTransformer:\n            if self.categorical:\n                transformer = OHECategoricalTransformer(**self.transformer_config)\n            else:\n                transformer = ClusterContinuousTransformer(**self.transformer_config)\n            if self.dtype.kind == \"M\":\n                transformer = DatetimeTransformer(transformer)\n            return transformer\n\n    def __init__(self, data: pd.DataFrame, metadata: Optional[dict] = {}):\n        self.columns: pd.Index = data.columns\n        self.raw_metadata: dict = metadata\n        if 
set(self.raw_metadata[\"columns\"].keys()) - set(self.columns):\n            raise ValueError(\"Metadata contains keys that do not appear amongst the columns.\")\n        self.dropped_columns = [cn for cn in self.columns if self.raw_metadata[\"columns\"].get(cn, None) == \"drop\"]\n        self.columns = self.columns.drop(self.dropped_columns)\n        self._metadata = {\n            cn: self.ColumnMetaData(cn, data[cn], self.raw_metadata[\"columns\"].get(cn, {})) for cn in self.columns\n        }\n        self.constraints = ConstraintGraph(self.raw_metadata.get(\"constraints\", []), self.columns, self._metadata)\n\n    def __getitem__(self, key: str) -> dict[str, Any]:\n        return self._metadata[key]\n\n    def __iter__(self) -> Iterator:\n        return iter(self._metadata.values())\n\n    def __repr__(self) -> None:\n        return yaml.dump(self._metadata, default_flow_style=False, sort_keys=False)\n\n    @classmethod\n    def from_path(cls, data: pd.DataFrame, path_str: str):\n        \"\"\"\n        Instantiate a MetaData object from a YAML file via a specified path.\n\n        Args:\n            data: The data to be used to infer / validate the metadata.\n            path_str: The path to the metadata YAML file.\n\n        Returns:\n            The metadata object.\n        \"\"\"\n        path = pathlib.Path(path_str)\n        if path.exists():\n            with open(path) as stream:\n                metadata = yaml.safe_load(stream)\n            # Filter out the expanded alias/anchor group as it is not needed\n            metadata = filter_dict(metadata, {\"column_types\"})\n        else:\n            warnings.warn(f\"No metadata found at {path}...\")\n            metadata = {\"columns\": {}}\n        return cls(data, metadata)\n\n    def _collapse(self, metadata: dict) -> dict:\n        \"\"\"\n        Given a metadata dictionary, rewrite to collapse duplicate column types in order to leverage YAML anchors and shrink the file.\n\n        Args:\n            metadata: The metadata dictionary to be rewritten.\n\n        Returns:\n            A rewritten metadata dictionary with collapsed column types and transformers.\n                The returned dictionary has the following structure:\n                {\n                    \"column_types\": dict,\n                    **metadata  # one entry for each column in \"columns\" that now reference the dicts above\n                }\n                - \"column_types\" is a dictionary mapping column type indices to column type configurations.\n                - \"**metadata\" contains the original metadata dictionary, with column types rewritten to use the indices and \"column_types\".\n        \"\"\"\n        c_index = 1\n        column_types = {}\n        column_type_counts = {}\n        for cn, cd in metadata[\"columns\"].items():\n            if cd not in column_types.values():\n                column_types[c_index] = cd if isinstance(cd, str) else cd.copy()\n                column_type_counts[c_index] = 1\n                c_index += 1\n            else:\n                cix = get_key_by_value(column_types, cd)\n                column_type_counts[cix] += 1\n\n        for cn, cd in metadata[\"columns\"].items():\n            cix = get_key_by_value(column_types, cd)\n            if column_type_counts[cix] > 1:\n                metadata[\"columns\"][cn] = column_types[cix]\n            else:\n                column_types.pop(cix)\n\n        return {\"column_types\": {i + 1: x for i, x in enumerate(column_types.values())}, 
**metadata}\n\n    def _assemble(self, collapse_yaml: bool) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Rearrange the metadata into a dictionary that can be written to a YAML file.\n\n        Args:\n            collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n\n        Returns:\n            A dictionary containing the assembled metadata.\n        \"\"\"\n        assembled_metadata = {\n            \"columns\": {\n                cn: {\n                    \"dtype\": (\n                        cmd.dtype.name\n                        if not hasattr(cmd, \"datetime_config\")\n                        else {\"name\": cmd.dtype.name, **cmd.datetime_config}\n                    ),\n                    \"categorical\": cmd.categorical,\n                }\n                for cn, cmd in self._metadata.items()\n            }\n        }\n        # We loop through the base dict above to add other parts if they are present in the metadata\n        for cn, cmd in self._metadata.items():\n            if cmd.missingness_strategy:\n                assembled_metadata[\"columns\"][cn][\"missingness\"] = (\n                    cmd.missingness_strategy.name\n                    if cmd.missingness_strategy.name != \"impute\"\n                    else {\"name\": cmd.missingness_strategy.name, \"impute\": cmd.missingness_strategy.impute}\n                )\n            if cmd.transformer_config:\n                assembled_metadata[\"columns\"][cn][\"transformer\"] = {\n                    **cmd.transformer_config,\n                    \"name\": cmd.transformer.__class__.__name__,\n                }\n\n        # Add back the dropped_columns not present in the metadata\n        if self.dropped_columns:\n            assembled_metadata[\"columns\"].update({cn: \"drop\" for cn in self.dropped_columns})\n\n        if collapse_yaml:\n            assembled_metadata = self._collapse(assembled_metadata)\n\n        # We add the constraints section after all of the formatting and processing above\n        # In general, the constraints are kept the same as the input (provided they passed validation)\n        # If `collapse_yaml` is specified, we output the minimum set of equivalent constraints\n        if self.constraints:\n            assembled_metadata[\"constraints\"] = (\n                [str(c) for c in self.constraints.minimal_constraints]\n                if collapse_yaml\n                else self.constraints.raw_constraint_strings\n            )\n        return assembled_metadata\n\n    def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:\n        \"\"\"\n        Writes metadata to a YAML file.\n\n        Args:\n            path: The path at which to write the metadata YAML file.\n            collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n        \"\"\"\n        with open(path, \"w\") as yaml_file:\n            yaml.safe_dump(\n                self._assemble(collapse_yaml),\n                yaml_file,\n                default_flow_style=False,\n                sort_keys=False,\n            )\n\n    def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:\n        \"\"\"\n        Map combinations of our metadata implementation to SDV's as required by SDMetrics.\n\n        Returns:\n            A dictionary containing the SDV metadata.\n        \"\"\"\n        sdv_metadata = {\n            \"columns\": {\n                cn: {\n                 
   \"sdtype\": (\n                        \"boolean\"\n                        if cmd.boolean\n                        else \"categorical\" if cmd.categorical else \"datetime\" if cmd.dtype.kind == \"M\" else \"numerical\"\n                    ),\n                }\n                for cn, cmd in self._metadata.items()\n            }\n        }\n        for cn, cmd in self._metadata.items():\n            if cmd.dtype.kind == \"M\":\n                sdv_metadata[\"columns\"][cn][\"format\"] = cmd.datetime_config[\"format\"]\n        return sdv_metadata\n\n    def save_constraint_graphs(self, path: pathlib.Path) -> None:\n        \"\"\"\n        Output the constraint graphs as HTML files.\n\n        Args:\n            path: The path at which to write the constraint graph HTML files.\n        \"\"\"\n        self.constraints._output_graphs_html(path)\n
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.ColumnMetaData","title":"ColumnMetaData","text":"Source code in src/nhssynth/modules/dataloader/metadata.py
class ColumnMetaData:\n    def __init__(self, name: str, data: pd.Series, raw: dict) -> None:\n        self.name = name\n        self.dtype: np.dtype = self._validate_dtype(data, raw.get(\"dtype\"))\n        self.categorical: bool = self._validate_categorical(data, raw.get(\"categorical\"))\n        self.missingness_strategy: GenericMissingnessStrategy = self._validate_missingness_strategy(\n            raw.get(\"missingness\")\n        )\n        self.transformer: ColumnTransformer = self._validate_transformer(raw.get(\"transformer\"))\n\n    def _validate_dtype(self, data: pd.Series, dtype_raw: Optional[Union[dict, str]] = None) -> np.dtype:\n        if isinstance(dtype_raw, dict):\n            dtype_name = dtype_raw.pop(\"name\", None)\n        elif isinstance(dtype_raw, str):\n            dtype_name = dtype_raw\n        else:\n            dtype_name = self._infer_dtype(data)\n        try:\n            dtype = np.dtype(dtype_name)\n        except TypeError:\n            warnings.warn(\n                f\"Invalid dtype specification '{dtype_name}' for column '{self.name}', ignoring dtype for this column\"\n            )\n            dtype = self._infer_dtype(data)\n        if dtype.kind == \"M\":\n            self._setup_datetime_config(data, dtype_raw)\n        elif dtype.kind in [\"f\", \"i\", \"u\"]:\n            self.rounding_scheme = self._validate_rounding_scheme(data, dtype, dtype_raw)\n        return dtype\n\n    def _infer_dtype(self, data: pd.Series) -> np.dtype:\n        return data.dtype.name\n\n    def _infer_datetime_format(self, data: pd.Series) -> str:\n        return _guess_datetime_format_for_array(data[data.notna()].astype(str).to_numpy())\n\n    def _setup_datetime_config(self, data: pd.Series, datetime_config: dict) -> dict:\n        \"\"\"\n        Add keys to `datetime_config` corresponding to args from the `pd.to_datetime` function\n        (see [the docs](https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html))\n        \"\"\"\n        if not isinstance(datetime_config, dict):\n            datetime_config = {}\n        else:\n            datetime_config = filter_dict(datetime_config, {\"format\", \"floor\"}, include=True)\n        if \"format\" not in datetime_config:\n            datetime_config[\"format\"] = self._infer_datetime_format(data)\n        self.datetime_config = datetime_config\n\n    def _validate_rounding_scheme(self, data: pd.Series, dtype: np.dtype, dtype_dict: dict) -> int:\n        if dtype_dict and \"rounding_scheme\" in dtype_dict:\n            return dtype_dict[\"rounding_scheme\"]\n        else:\n            if dtype.kind != \"f\":\n                return 1.0\n            roundable_data = data[data.notna()]\n            for i in range(np.finfo(dtype).precision):\n                if (roundable_data.round(i) == roundable_data).all():\n                    return 10**-i\n        return None\n\n    def _validate_categorical(self, data: pd.Series, categorical: Optional[bool] = None) -> bool:\n        if categorical is None:\n            return self._infer_categorical(data)\n        elif not isinstance(categorical, bool):\n            warnings.warn(\n                f\"Invalid categorical '{categorical}' for column '{self.name}', ignoring categorical for this column\"\n            )\n            return self._infer_categorical(data)\n        else:\n            self.boolean = data.nunique() <= 2\n            return categorical\n\n    def _infer_categorical(self, data: pd.Series) -> bool:\n        self.boolean = data.nunique() <= 2\n    
    return data.nunique() <= 10 or self.dtype.kind == \"O\"\n\n    def _validate_missingness_strategy(self, missingness_strategy: Optional[Union[dict, str]]) -> tuple[str, dict]:\n        if not missingness_strategy:\n            return None\n        if isinstance(missingness_strategy, dict):\n            impute = missingness_strategy.get(\"impute\", None)\n            strategy = \"impute\" if impute else missingness_strategy.get(\"strategy\", None)\n        else:\n            strategy = missingness_strategy\n        if (\n            strategy not in MISSINGNESS_STRATEGIES\n            or (strategy == \"impute\" and impute == \"mean\" and self.dtype.kind != \"f\")\n            or (strategy == \"impute\" and not impute)\n        ):\n            warnings.warn(\n                f\"Invalid missingness strategy '{missingness_strategy}' for column '{self.name}', ignoring missingness strategy for this column\"\n            )\n            return None\n        return (\n            MISSINGNESS_STRATEGIES[strategy](impute) if strategy == \"impute\" else MISSINGNESS_STRATEGIES[strategy]()\n        )\n\n    def _validate_transformer(self, transformer: Optional[Union[dict, str]] = {}) -> tuple[str, dict]:\n        # if transformer is neither a dict nor a str statement below will raise a TypeError\n        if isinstance(transformer, dict):\n            self.transformer_name = transformer.get(\"name\")\n            self.transformer_config = filter_dict(transformer, \"name\")\n        elif isinstance(transformer, str):\n            self.transformer_name = transformer\n            self.transformer_config = {}\n        else:\n            if transformer is not None:\n                warnings.warn(\n                    f\"Invalid transformer config '{transformer}' for column '{self.name}', ignoring transformer for this column\"\n                )\n            self.transformer_name = None\n            self.transformer_config = {}\n        if not self.transformer_name:\n            return self._infer_transformer()\n        else:\n            try:\n                return eval(self.transformer_name)(**self.transformer_config)\n            except NameError:\n                warnings.warn(\n                    f\"Invalid transformer '{self.transformer_name}' or config '{self.transformer_config}' for column '{self.name}', ignoring transformer for this column\"\n                )\n                return self._infer_transformer()\n\n    def _infer_transformer(self) -> ColumnTransformer:\n        if self.categorical:\n            transformer = OHECategoricalTransformer(**self.transformer_config)\n        else:\n            transformer = ClusterContinuousTransformer(**self.transformer_config)\n        if self.dtype.kind == \"M\":\n            transformer = DatetimeTransformer(transformer)\n        return transformer\n
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.from_path","title":"from_path(data, path_str) classmethod","text":"

Instantiate a MetaData object from a YAML file via a specified path.

Parameters:

- data (DataFrame, required): The data to be used to infer / validate the metadata.
- path_str (str, required): The path to the metadata YAML file.

Returns:

- The metadata object.

Source code in src/nhssynth/modules/dataloader/metadata.py
@classmethod\ndef from_path(cls, data: pd.DataFrame, path_str: str):\n    \"\"\"\n    Instantiate a MetaData object from a YAML file via a specified path.\n\n    Args:\n        data: The data to be used to infer / validate the metadata.\n        path_str: The path to the metadata YAML file.\n\n    Returns:\n        The metadata object.\n    \"\"\"\n    path = pathlib.Path(path_str)\n    if path.exists():\n        with open(path) as stream:\n            metadata = yaml.safe_load(stream)\n        # Filter out the expanded alias/anchor group as it is not needed\n        metadata = filter_dict(metadata, {\"column_types\"})\n    else:\n        warnings.warn(f\"No metadata found at {path}...\")\n        metadata = {\"columns\": {}}\n    return cls(data, metadata)\n
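A sketch of a small metadata YAML and how it might be loaded; the column names and file name are hypothetical, any column not listed is inferred from the data, and "drop" excludes a column entirely:

import pandas as pd
from nhssynth.modules.dataloader.metadata import MetaData

with open("my_dataset_metadata.yaml", "w") as f:
    f.write(
        "columns:\n"
        "  age:\n"
        "    dtype: int64\n"
        "    categorical: false\n"
        "  sex:\n"
        "    dtype: object\n"
        "    categorical: true\n"
        "  nhs_number: drop\n"
    )

data = pd.DataFrame({
    "age": [34, 57, 61, 22],
    "sex": ["F", "M", "F", "M"],
    "nhs_number": ["001", "002", "003", "004"],
})
metadata = MetaData.from_path(data, "my_dataset_metadata.yaml")
print(metadata.dropped_columns)  # ['nhs_number']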
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.get_sdv_metadata","title":"get_sdv_metadata()","text":"

Map combinations of our metadata implementation to SDV's as required by SDMetrics.

Returns:

- dict[str, dict[str, dict[str, str]]]: A dictionary containing the SDV metadata.

Source code in src/nhssynth/modules/dataloader/metadata.py
def get_sdv_metadata(self) -> dict[str, dict[str, dict[str, str]]]:\n    \"\"\"\n    Map combinations of our metadata implementation to SDV's as required by SDMetrics.\n\n    Returns:\n        A dictionary containing the SDV metadata.\n    \"\"\"\n    sdv_metadata = {\n        \"columns\": {\n            cn: {\n                \"sdtype\": (\n                    \"boolean\"\n                    if cmd.boolean\n                    else \"categorical\" if cmd.categorical else \"datetime\" if cmd.dtype.kind == \"M\" else \"numerical\"\n                ),\n            }\n            for cn, cmd in self._metadata.items()\n        }\n    }\n    for cn, cmd in self._metadata.items():\n        if cmd.dtype.kind == \"M\":\n            sdv_metadata[\"columns\"][cn][\"format\"] = cmd.datetime_config[\"format\"]\n    return sdv_metadata\n
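For a dataset with, say, a numerical age column, a binary sex column, and a datetime admitted column (hypothetical names), the returned dictionary has roughly this shape:

sdv_metadata = {
    "columns": {
        "age": {"sdtype": "numerical"},
        "sex": {"sdtype": "boolean"},  # at most two unique values maps to "boolean"
        "admitted": {"sdtype": "datetime", "format": "%Y-%m-%d"},  # datetime columns also carry their format
    }
}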
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.save","title":"save(path, collapse_yaml)","text":"

Writes metadata to a YAML file.

Parameters:

- path (Path, required): The path at which to write the metadata YAML file.
- collapse_yaml (bool, required): A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.

Source code in src/nhssynth/modules/dataloader/metadata.py
def save(self, path: pathlib.Path, collapse_yaml: bool) -> None:\n    \"\"\"\n    Writes metadata to a YAML file.\n\n    Args:\n        path: The path at which to write the metadata YAML file.\n        collapse_yaml: A boolean indicating whether to collapse the YAML representation of the metadata, reducing duplication.\n    \"\"\"\n    with open(path, \"w\") as yaml_file:\n        yaml.safe_dump(\n            self._assemble(collapse_yaml),\n            yaml_file,\n            default_flow_style=False,\n            sort_keys=False,\n        )\n
"},{"location":"reference/modules/dataloader/metadata/#nhssynth.modules.dataloader.metadata.MetaData.save_constraint_graphs","title":"save_constraint_graphs(path)","text":"

Output the constraint graphs as HTML files.

Parameters:

- path (Path, required): The path at which to write the constraint graph HTML files.

Source code in src/nhssynth/modules/dataloader/metadata.py
def save_constraint_graphs(self, path: pathlib.Path) -> None:\n    \"\"\"\n    Output the constraint graphs as HTML files.\n\n    Args:\n        path: The path at which to write the constraint graph HTML files.\n    \"\"\"\n    self.constraints._output_graphs_html(path)\n
"},{"location":"reference/modules/dataloader/metatransformer/","title":"metatransformer","text":""},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer","title":"MetaTransformer","text":"

The metatransformer is responsible for transforming the input dataset into a format that can be used by the model module, and for transforming that module's output back to the original format of the input dataset.

Parameters:

- dataset (DataFrame, required): The raw input DataFrame.
- metadata (Optional[MetaData], default None): Optionally, a MetaData object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset.
- missingness_strategy (Optional[str], default 'augment'): The missingness strategy to use. Defaults to augmenting missing values in the data, see the missingness strategies for more information.
- impute_value (Optional[Any], default None): Only used when missingness_strategy is set to 'impute'. The value to use when imputing missing values in the data.

After calling MetaTransformer.apply(), the following attributes and methods will be available:

Attributes:

- typed_dataset (DataFrame): The dataset with the dtypes applied.
- post_missingness_strategy_dataset (DataFrame): The dataset with the missingness strategies applied.
- transformed_dataset (DataFrame): The transformed dataset.
- single_column_indices (list[int]): The indices of the columns that were transformed into a single column.
- multi_column_indices (list[list[int]]): The indices of the columns that were transformed into multiple columns.

Methods:

  • get_typed_dataset(): Returns the typed dataset.
  • get_prepared_dataset(): Returns the dataset with the missingness strategies applied.
  • get_transformed_dataset(): Returns the transformed dataset.
  • get_multi_and_single_column_indices(): Returns the indices of the columns that were transformed into one or multiple column(s).
  • get_sdv_metadata(): Returns the metadata in the correct format for SDMetrics.
  • save_metadata(): Saves the metadata to a file.
  • save_constraint_graphs(): Saves the constraint graphs to a file.

Note that mt.apply is a helper function that runs mt.apply_dtypes, mt.apply_missingness_strategy and mt.transform in sequence. This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
class MetaTransformer:\n    \"\"\"\n    The metatransformer is responsible for transforming input dataset into a format that can be used by the `model` module, and for transforming\n    this module's output back to the original format of the input dataset.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata: Optionally, a [`MetaData`][nhssynth.modules.dataloader.metadata.MetaData] object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset.\n        missingness_strategy: The missingness strategy to use. Defaults to augmenting missing values in the data, see [the missingness strategies][nhssynth.modules.dataloader.missingness] for more information.\n        impute_value: Only used when `missingness_strategy` is set to 'impute'. The value to use when imputing missing values in the data.\n\n    After calling `MetaTransformer.apply()`, the following attributes and methods will be available:\n\n    Attributes:\n        typed_dataset (pd.DataFrame): The dataset with the dtypes applied.\n        post_missingness_strategy_dataset (pd.DataFrame): The dataset with the missingness strategies applied.\n        transformed_dataset (pd.DataFrame): The transformed dataset.\n        single_column_indices (list[int]): The indices of the columns that were transformed into a single column.\n        multi_column_indices (list[list[int]]): The indices of the columns that were transformed into multiple columns.\n\n    **Methods:**\n\n    - `get_typed_dataset()`: Returns the typed dataset.\n    - `get_prepared_dataset()`: Returns the dataset with the missingness strategies applied.\n    - `get_transformed_dataset()`: Returns the transformed dataset.\n    - `get_multi_and_single_column_indices()`: Returns the indices of the columns that were transformed into one or multiple column(s).\n    - `get_sdv_metadata()`: Returns the metadata in the correct format for SDMetrics.\n    - `save_metadata()`: Saves the metadata to a file.\n    - `save_constraint_graphs()`: Saves the constraint graphs to a file.\n\n    Note that `mt.apply` is a helper function that runs `mt.apply_dtypes`, `mt.apply_missingness_strategy` and `mt.transform` in sequence.\n    This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: pd.DataFrame,\n        metadata: Optional[MetaData] = None,\n        missingness_strategy: Optional[str] = \"augment\",\n        impute_value: Optional[Any] = None,\n    ):\n        self._raw_dataset: pd.DataFrame = dataset\n        self._metadata: MetaData = metadata or MetaData(dataset)\n        if missingness_strategy == \"impute\":\n            assert (\n                impute_value is not None\n            ), \"`impute_value` of the `MetaTransformer` must be specified (via the --impute flag) when using the imputation missingness strategy\"\n            self._impute_value = impute_value\n        self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy]\n\n    @classmethod\n    def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:\n        \"\"\"\n        Instantiates a MetaTransformer from a metadata file via a provided path.\n\n        Args:\n            dataset: The raw input DataFrame.\n            metadata_path: The path to the metadata file.\n\n        Returns:\n            A MetaTransformer object.\n        \"\"\"\n        return cls(dataset, MetaData.from_path(dataset, 
metadata_path), **kwargs)\n\n    @classmethod\n    def from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:\n        \"\"\"\n        Instantiates a MetaTransformer from a metadata dictionary.\n\n        Args:\n            dataset: The raw input DataFrame.\n            metadata: A dictionary of raw metadata.\n\n        Returns:\n            A MetaTransformer object.\n        \"\"\"\n        return cls(dataset, MetaData(dataset, metadata), **kwargs)\n\n    def drop_columns(self) -> None:\n        \"\"\"\n        Drops columns from the dataset that are not in the `MetaData`.\n        \"\"\"\n        self._raw_dataset = self._raw_dataset[self._metadata.columns]\n\n    def _apply_rounding_scheme(self, working_column: pd.Series, rounding_scheme: float) -> pd.Series:\n        \"\"\"\n        A rounding scheme takes the form of the smallest value that should be rounded to 0, i.e. 0.01 for 2dp.\n        We first round to the nearest multiple in the standard way, through dividing, rounding and then multiplying.\n        However, this can lead to floating point errors, so we then round to the number of decimal places required by the rounding scheme.\n\n        e.g. `np.round(0.15 / 0.1) * 0.1` will erroneously return 0.1.\n\n        Args:\n            working_column: The column to apply the rounding scheme to.\n            rounding_scheme: The rounding scheme to apply.\n\n        Returns:\n            The column with the rounding scheme applied.\n        \"\"\"\n        working_column = np.round(working_column / rounding_scheme) * rounding_scheme\n        return working_column.round(max(0, int(np.ceil(np.log10(1 / rounding_scheme)))))\n\n    def _apply_dtype(\n        self,\n        working_column: pd.Series,\n        column_metadata: MetaData.ColumnMetaData,\n    ) -> pd.Series:\n        \"\"\"\n        Given a `working_column`, the dtype specified in the `column_metadata` is applied to it.\n         - Datetime columns are floored, and their format is inferred.\n         - Rounding schemes are applied to numeric columns if specified.\n         - Columns with missing values have their dtype converted to the pandas equivalent to allow for NA values.\n\n        Args:\n            working_column: The column to apply the dtype to.\n            column_metadata: The metadata for the column.\n\n        Returns:\n            The column with the dtype applied.\n        \"\"\"\n        dtype = column_metadata.dtype\n        try:\n            if dtype.kind == \"M\":\n                working_column = pd.to_datetime(working_column, format=column_metadata.datetime_config.get(\"format\"))\n                if column_metadata.datetime_config.get(\"floor\"):\n                    working_column = working_column.dt.floor(column_metadata.datetime_config.get(\"floor\"))\n                    column_metadata.datetime_config[\"format\"] = column_metadata._infer_datetime_format(working_column)\n                return working_column\n            else:\n                if hasattr(column_metadata, \"rounding_scheme\") and column_metadata.rounding_scheme is not None:\n                    working_column = self._apply_rounding_scheme(working_column, column_metadata.rounding_scheme)\n                # If there are missing values in the column, we need to use the pandas equivalent of the dtype to allow for NA values\n                if working_column.isnull().any() and dtype.kind in [\"i\", \"u\", \"f\"]:\n                    return working_column.astype(dtype.name.capitalize())\n                else:\n            
        return working_column.astype(dtype)\n        except ValueError:\n            raise ValueError(f\"{sys.exc_info()[1]}\\nError applying dtype '{dtype}' to column '{working_column.name}'\")\n\n    def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Applies dtypes from the metadata to `dataset`.\n\n        Returns:\n            The dataset with the dtypes applied.\n        \"\"\"\n        working_data = data.copy()\n        for column_metadata in self._metadata:\n            working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)\n        return working_data\n\n    def apply_missingness_strategy(self) -> pd.DataFrame:\n        \"\"\"\n        Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or\n        column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness\n        is not resolved, instead a new column / value is added for later transformation.\n\n        Returns:\n            The dataset with the missingness strategies applied.\n        \"\"\"\n        working_data = self.typed_dataset.copy()\n        for column_metadata in self._metadata:\n            if not column_metadata.missingness_strategy:\n                column_metadata.missingness_strategy = (\n                    self._missingness_strategy(self._impute_value)\n                    if hasattr(self, \"_impute_value\")\n                    else self._missingness_strategy()\n                )\n            if not working_data[column_metadata.name].isnull().any():\n                continue\n            working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)\n        return working_data\n\n    # def apply_constraints(self) -> pd.DataFrame:\n    #     working_data = self.post_missingness_strategy_dataset.copy()\n    #     for constraint in self._metadata.constraints:\n    #         working_data = constraint.apply(working_data)\n    #     return working_data\n\n    def _get_missingness_carrier(self, column_metadata: MetaData.ColumnMetaData) -> Union[pd.Series, Any]:\n        \"\"\"\n        In the case of the `AugmentMissingnessStrategy`, a `missingness_carrier` has been determined for each column.\n        For continuous columns this is an indicator column for the presence of NaN values.\n        For categorical columns this is the value to be used to represent missingness as a category.\n\n        Args:\n            column_metadata: The metadata for the column.\n\n        Returns:\n            The missingness carrier for the column.\n        \"\"\"\n        missingness_carrier = getattr(column_metadata.missingness_strategy, \"missingness_carrier\", None)\n        if missingness_carrier in self.post_missingness_strategy_dataset.columns:\n            return self.post_missingness_strategy_dataset[missingness_carrier]\n        else:\n            return missingness_carrier\n\n    def transform(self) -> pd.DataFrame:\n        \"\"\"\n        Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.\n\n        Returns:\n            The transformed dataset.\n        \"\"\"\n        transformed_columns = []\n        self.single_column_indices = []\n        self.multi_column_indices = []\n        col_counter = 0\n        working_data = self.post_missingness_strategy_dataset.copy()\n\n        # iteratively build the transformed df\n        for column_metadata in tqdm(\n   
         self._metadata, desc=\"Transforming data\", unit=\"column\", total=len(self._metadata.columns)\n        ):\n            missingness_carrier = self._get_missingness_carrier(column_metadata)\n            transformed_data = column_metadata.transformer.apply(\n                working_data[column_metadata.name], missingness_carrier\n            )\n            transformed_columns.append(transformed_data)\n\n            # track single and multi column indices to supply to the model\n            if isinstance(transformed_data, pd.DataFrame) and transformed_data.shape[1] > 1:\n                num_to_add = transformed_data.shape[1]\n                if not column_metadata.categorical:\n                    self.single_column_indices.append(col_counter)\n                    col_counter += 1\n                    num_to_add -= 1\n                self.multi_column_indices.append(list(range(col_counter, col_counter + num_to_add)))\n                col_counter += num_to_add\n            else:\n                self.single_column_indices.append(col_counter)\n                col_counter += 1\n\n        return pd.concat(transformed_columns, axis=1)\n\n    def apply(self) -> pd.DataFrame:\n        \"\"\"\n        Applies the various steps of the MetaTransformer to a passed DataFrame.\n\n        Returns:\n            The transformed dataset.\n        \"\"\"\n        self.drop_columns()\n        self.typed_dataset = self.apply_dtypes(self._raw_dataset)\n        self.post_missingness_strategy_dataset = self.apply_missingness_strategy()\n        # self.constrained_dataset = self.apply_constraints()\n        self.transformed_dataset = self.transform()\n        return self.transformed_dataset\n\n    def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Reverses the transformation applied by the MetaTransformer.\n\n        Args:\n            dataset: The transformed dataset.\n\n        Returns:\n            The original dataset.\n        \"\"\"\n        for column_metadata in self._metadata:\n            dataset = column_metadata.transformer.revert(dataset)\n        return self.apply_dtypes(dataset)\n\n    def get_typed_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"typed_dataset\"):\n            raise ValueError(\n                \"The typed dataset has not yet been created. Call `mt.apply()` (or `mt.apply_dtypes()`) first.\"\n            )\n        return self.typed_dataset\n\n    def get_prepared_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"prepared_dataset\"):\n            raise ValueError(\n                \"The prepared dataset has not yet been created. Call `mt.apply()` (or `mt.apply_missingness_strategy()`) first.\"\n            )\n        return self.prepared_dataset\n\n    def get_transformed_dataset(self) -> pd.DataFrame:\n        if not hasattr(self, \"transformed_dataset\"):\n            raise ValueError(\n                \"The prepared dataset has not yet been created. 
Call `mt.apply()` (or `mt.transform()`) first.\"\n            )\n        return self.transformed_dataset\n\n    def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:\n        \"\"\"\n        Returns the indices of the columns that were transformed into one or multiple column(s).\n\n        Returns:\n            A tuple containing the indices of the single and multi columns.\n        \"\"\"\n        if not hasattr(self, \"multi_column_indices\") or not hasattr(self, \"single_column_indices\"):\n            raise ValueError(\n                \"The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first.\"\n            )\n        return self.multi_column_indices, self.single_column_indices\n\n    def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.\n\n        Returns:\n            The metadata in the correct format for SDMetrics.\n        \"\"\"\n        return self._metadata.get_sdv_metadata()\n\n    def save_metadata(self, path: pathlib.Path, collapse_yaml: bool = False) -> None:\n        return self._metadata.save(path, collapse_yaml)\n\n    def save_constraint_graphs(self, path: pathlib.Path) -> None:\n        return self._metadata.constraints._output_graphs_html(path)\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply","title":"apply()","text":"

Applies the various steps of the MetaTransformer to a passed DataFrame.

Returns:

DataFrame: The transformed dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply(self) -> pd.DataFrame:\n    \"\"\"\n    Applies the various steps of the MetaTransformer to a passed DataFrame.\n\n    Returns:\n        The transformed dataset.\n    \"\"\"\n    self.drop_columns()\n    self.typed_dataset = self.apply_dtypes(self._raw_dataset)\n    self.post_missingness_strategy_dataset = self.apply_missingness_strategy()\n    # self.constrained_dataset = self.apply_constraints()\n    self.transformed_dataset = self.transform()\n    return self.transformed_dataset\n
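For orientation, here is a minimal end-to-end sketch of how `apply()` fits into the pipeline. The file names and the YAML metadata format are hypothetical assumptions; only methods documented on this page are used.

```python
import pandas as pd

from nhssynth.modules.dataloader.metatransformer import MetaTransformer

# Hypothetical inputs: a tabular CSV and a matching metadata file.
dataset = pd.read_csv("inpatient_episodes.csv")
mt = MetaTransformer.from_path(dataset, "inpatient_episodes_metadata.yaml")

# Runs drop_columns -> apply_dtypes -> apply_missingness_strategy -> transform.
transformed = mt.apply()

typed = mt.get_typed_dataset()  # intermediate dataset after dtype application
assert transformed.equals(mt.get_transformed_dataset())
```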
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply_dtypes","title":"apply_dtypes(data)","text":"

Applies dtypes from the metadata to dataset.

Returns:

DataFrame: The dataset with the dtypes applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Applies dtypes from the metadata to `dataset`.\n\n    Returns:\n        The dataset with the dtypes applied.\n    \"\"\"\n    working_data = data.copy()\n    for column_metadata in self._metadata:\n        working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)\n    return working_data\n
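The effect of this step is a metadata-driven, column-wise `astype`. A pandas-only sketch with an illustrative dtype mapping (the real mapping comes from the `MetaData` object, not this dictionary):

```python
import pandas as pd

df = pd.DataFrame({"age": ["34", "57"], "sex": ["F", "M"]})
dtypes = {"age": "int64", "sex": "category"}  # illustrative mapping only
typed = df.astype(dtypes)
print(typed.dtypes)  # age: int64, sex: category
```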
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.apply_missingness_strategy","title":"apply_missingness_strategy()","text":"

Resolves missingness in the dataset via the MetaTransformer's global missingness strategy or column-wise missingness strategies. In the case of the AugmentMissingnessStrategy, the missingness is not resolved; instead, a new column or value is added for later transformation.

Returns:

DataFrame: The dataset with the missingness strategies applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_missingness_strategy(self) -> pd.DataFrame:\n    \"\"\"\n    Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or\n    column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness\n    is not resolved, instead a new column / value is added for later transformation.\n\n    Returns:\n        The dataset with the missingness strategies applied.\n    \"\"\"\n    working_data = self.typed_dataset.copy()\n    for column_metadata in self._metadata:\n        if not column_metadata.missingness_strategy:\n            column_metadata.missingness_strategy = (\n                self._missingness_strategy(self._impute_value)\n                if hasattr(self, \"_impute_value\")\n                else self._missingness_strategy()\n            )\n        if not working_data[column_metadata.name].isnull().any():\n            continue\n        working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)\n    return working_data\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.drop_columns","title":"drop_columns()","text":"

Drops columns from the dataset that are not in the MetaData.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def drop_columns(self) -> None:\n    \"\"\"\n    Drops columns from the dataset that are not in the `MetaData`.\n    \"\"\"\n    self._raw_dataset = self._raw_dataset[self._metadata.columns]\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.from_dict","title":"from_dict(dataset, metadata, **kwargs) classmethod","text":"

Instantiates a MetaTransformer from a metadata dictionary.

Parameters:

dataset (DataFrame): The raw input DataFrame. [required]
metadata (dict): A dictionary of raw metadata. [required]

Returns:

Self: A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod\ndef from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:\n    \"\"\"\n    Instantiates a MetaTransformer from a metadata dictionary.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata: A dictionary of raw metadata.\n\n    Returns:\n        A MetaTransformer object.\n    \"\"\"\n    return cls(dataset, MetaData(dataset, metadata), **kwargs)\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.from_path","title":"from_path(dataset, metadata_path, **kwargs) classmethod","text":"

Instantiates a MetaTransformer from a metadata file via a provided path.

Parameters:

dataset (DataFrame): The raw input DataFrame. [required]
metadata_path (str): The path to the metadata file. [required]

Returns:

Self: A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod\ndef from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:\n    \"\"\"\n    Instantiates a MetaTransformer from a metadata file via a provided path.\n\n    Args:\n        dataset: The raw input DataFrame.\n        metadata_path: The path to the metadata file.\n\n    Returns:\n        A MetaTransformer object.\n    \"\"\"\n    return cls(dataset, MetaData.from_path(dataset, metadata_path), **kwargs)\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.get_multi_and_single_column_indices","title":"get_multi_and_single_column_indices()","text":"

Returns the indices of the columns that were transformed into one or multiple column(s).

Returns:

tuple[list[int], list[int]]: A tuple containing the multi-column indices and the single-column indices, in that order.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:\n    \"\"\"\n    Returns the indices of the columns that were transformed into one or multiple column(s).\n\n    Returns:\n        A tuple containing the indices of the single and multi columns.\n    \"\"\"\n    if not hasattr(self, \"multi_column_indices\") or not hasattr(self, \"single_column_indices\"):\n        raise ValueError(\n            \"The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first.\"\n        )\n    return self.multi_column_indices, self.single_column_indices\n
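To make the return value concrete, here is a hypothetical layout: a continuous column transformed into a normalised column plus three cluster-component columns, followed by a one-hot-encoded categorical column with two categories.

```python
# Hypothetical transformed column layout:
#   0: age_normalised                  (single continuous column)
#   1: age_c1, 2: age_c2, 3: age_c3    (one-hot cluster assignment block)
#   4: sex_F, 5: sex_M                 (one-hot category block)
multi_column_indices = [[1, 2, 3], [4, 5]]
single_column_indices = [0]

# A downstream model can treat index 0 as an ordinary continuous feature and
# apply e.g. a softmax over each multi-column block.
```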
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.get_sdv_metadata","title":"get_sdv_metadata()","text":"

Calls the MetaData method to reformat its contents into the correct format for use with SDMetrics.

Returns:

dict[str, dict[str, Any]]: The metadata in the correct format for SDMetrics.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:\n    \"\"\"\n    Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.\n\n    Returns:\n        The metadata in the correct format for SDMetrics.\n    \"\"\"\n    return self._metadata.get_sdv_metadata()\n
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.inverse_apply","title":"inverse_apply(dataset)","text":"

Reverses the transformation applied by the MetaTransformer.

Parameters:

dataset (DataFrame): The transformed dataset. [required]

Returns:

DataFrame: The original dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Reverses the transformation applied by the MetaTransformer.\n\n    Args:\n        dataset: The transformed dataset.\n\n    Returns:\n        The original dataset.\n    \"\"\"\n    for column_metadata in self._metadata:\n        dataset = column_metadata.transformer.revert(dataset)\n    return self.apply_dtypes(dataset)\n
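Continuing the hypothetical sketch from the `apply()` section above (same `mt` and `transformed`), transformed or synthetic data can be mapped back to the original representation:

```python
# `transformed` could equally be a synthetic sample in the transformed space.
recovered = mt.inverse_apply(transformed.copy())
print(recovered.dtypes)  # same dtypes as the typed dataset
```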
"},{"location":"reference/modules/dataloader/metatransformer/#nhssynth.modules.dataloader.metatransformer.MetaTransformer.transform","title":"transform()","text":"

Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.

Returns:

DataFrame: The transformed dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def transform(self) -> pd.DataFrame:\n    \"\"\"\n    Prepares the dataset by applying each of the columns' transformers and recording the indices of the single and multi columns.\n\n    Returns:\n        The transformed dataset.\n    \"\"\"\n    transformed_columns = []\n    self.single_column_indices = []\n    self.multi_column_indices = []\n    col_counter = 0\n    working_data = self.post_missingness_strategy_dataset.copy()\n\n    # iteratively build the transformed df\n    for column_metadata in tqdm(\n        self._metadata, desc=\"Transforming data\", unit=\"column\", total=len(self._metadata.columns)\n    ):\n        missingness_carrier = self._get_missingness_carrier(column_metadata)\n        transformed_data = column_metadata.transformer.apply(\n            working_data[column_metadata.name], missingness_carrier\n        )\n        transformed_columns.append(transformed_data)\n\n        # track single and multi column indices to supply to the model\n        if isinstance(transformed_data, pd.DataFrame) and transformed_data.shape[1] > 1:\n            num_to_add = transformed_data.shape[1]\n            if not column_metadata.categorical:\n                self.single_column_indices.append(col_counter)\n                col_counter += 1\n                num_to_add -= 1\n            self.multi_column_indices.append(list(range(col_counter, col_counter + num_to_add)))\n            col_counter += num_to_add\n        else:\n            self.single_column_indices.append(col_counter)\n            col_counter += 1\n\n    return pd.concat(transformed_columns, axis=1)\n
"},{"location":"reference/modules/dataloader/missingness/","title":"missingness","text":""},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.AugmentMissingnessStrategy","title":"AugmentMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Source code in src/nhssynth/modules/dataloader/missingness.py
class AugmentMissingnessStrategy(GenericMissingnessStrategy):\n    def __init__(self) -> None:\n        super().__init__(\"augment\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata enabling the correct set up of the missingness strategy.\n\n        Returns:\n            The dataset, potentially with a new column representing the missingness for the column added.\n        \"\"\"\n        if column_metadata.categorical:\n            if column_metadata.dtype.kind == \"O\":\n                self.missingness_carrier = column_metadata.name + \"_missing\"\n            else:\n                self.missingness_carrier = data[column_metadata.name].min() - 1\n        else:\n            self.missingness_carrier = column_metadata.name + \"_missing\"\n            data[self.missingness_carrier] = data[column_metadata.name].isnull().astype(int)\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.AugmentMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.

Parameters:

data (DataFrame): The dataset. [required]
column_metadata (ColumnMetaData): The column metadata enabling the correct setup of the missingness strategy. [required]

Returns:

DataFrame: The dataset, potentially with a new column representing the missingness for the column added.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Impute missingness with the model. To do this we create a new column for continuous features and a new category for categorical features.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata enabling the correct set up of the missingness strategy.\n\n    Returns:\n        The dataset, potentially with a new column representing the missingness for the column added.\n    \"\"\"\n    if column_metadata.categorical:\n        if column_metadata.dtype.kind == \"O\":\n            self.missingness_carrier = column_metadata.name + \"_missing\"\n        else:\n            self.missingness_carrier = data[column_metadata.name].min() - 1\n    else:\n        self.missingness_carrier = column_metadata.name + \"_missing\"\n        data[self.missingness_carrier] = data[column_metadata.name].isnull().astype(int)\n    return data\n
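A pandas-only sketch of what the augmentation amounts to for a continuous column (toy data; the `<name>_missing` naming follows the convention above):

```python
import pandas as pd

data = pd.DataFrame({"bmi": [22.5, None, 30.1]})

# Continuous column: add an indicator so the model can learn the missingness pattern.
data["bmi_missing"] = data["bmi"].isnull().astype(int)
print(data)
#     bmi  bmi_missing
# 0  22.5            0
# 1   NaN            1
# 2  30.1            0
```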
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.DropMissingnessStrategy","title":"DropMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Drop missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class DropMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Drop missingness strategy.\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__(\"drop\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Drop rows containing missing values in the appropriate column.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata.\n\n        Returns:\n            The dataset with rows containing missing values in the appropriate column dropped.\n        \"\"\"\n        return data.dropna(subset=[column_metadata.name]).reset_index(drop=True)\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.DropMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Drop rows containing missing values in the appropriate column.

Parameters:

data (DataFrame): The dataset. [required]
column_metadata (ColumnMetaData): The column metadata. [required]

Returns:

DataFrame: The dataset with rows containing missing values in the appropriate column dropped.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Drop rows containing missing values in the appropriate column.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata.\n\n    Returns:\n        The dataset with rows containing missing values in the appropriate column dropped.\n    \"\"\"\n    return data.dropna(subset=[column_metadata.name]).reset_index(drop=True)\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.GenericMissingnessStrategy","title":"GenericMissingnessStrategy","text":"

Bases: ABC

Generic missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class GenericMissingnessStrategy(ABC):\n    \"\"\"Generic missingness strategy.\"\"\"\n\n    def __init__(self, name: str) -> None:\n        super().__init__()\n        self.name: str = name\n\n    @abstractmethod\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"Remove missingness.\"\"\"\n        pass\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.GenericMissingnessStrategy.remove","title":"remove(data, column_metadata) abstractmethod","text":"

Remove missingness.

Source code in src/nhssynth/modules/dataloader/missingness.py
@abstractmethod\ndef remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"Remove missingness.\"\"\"\n    pass\n
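As a sketch of how a new strategy could plug into this interface, a hypothetical forward-fill strategy (not part of the package) only needs to implement `remove`:

```python
import pandas as pd

from nhssynth.modules.dataloader.missingness import GenericMissingnessStrategy


class ForwardFillMissingnessStrategy(GenericMissingnessStrategy):
    """Hypothetical strategy: carry the last observed value forward."""

    def __init__(self) -> None:
        super().__init__("ffill")

    def remove(self, data: pd.DataFrame, column_metadata) -> pd.DataFrame:
        data[column_metadata.name] = data[column_metadata.name].ffill()
        return data
```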
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.ImputeMissingnessStrategy","title":"ImputeMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Impute missingness strategy: fills missing values with the mean, median, mode, or a constant value.

Source code in src/nhssynth/modules/dataloader/missingness.py
class ImputeMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Impute missingness with mean strategy.\"\"\"\n\n    def __init__(self, impute: Any) -> None:\n        super().__init__(\"impute\")\n        self.impute = impute.lower() if isinstance(impute, str) else impute\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"\n        Impute missingness in the data via the `impute` strategy. 'Special' values trigger specific behaviour.\n\n        Args:\n            data: The dataset.\n            column_metadata: The column metadata.\n\n        Returns:\n            The dataset with missing values in the appropriate column replaced with imputed ones.\n        \"\"\"\n        if (self.impute == \"mean\" or self.impute == \"median\") and column_metadata.categorical:\n            warnings.warn(\"Cannot impute mean or median for categorical data, using mode instead.\")\n            self.imputation_value = data[column_metadata.name].mode()[0]\n        elif self.impute == \"mean\":\n            self.imputation_value = data[column_metadata.name].mean()\n        elif self.impute == \"median\":\n            self.imputation_value = data[column_metadata.name].median()\n        elif self.impute == \"mode\":\n            self.imputation_value = data[column_metadata.name].mode()[0]\n        else:\n            self.imputation_value = self.impute\n        self.imputation_value = column_metadata.dtype.type(self.imputation_value)\n        try:\n            data[column_metadata.name].fillna(self.imputation_value, inplace=True)\n        except AssertionError:\n            raise ValueError(f\"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.\")\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.ImputeMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Impute missingness in the data via the impute strategy. 'Special' values trigger specific behaviour.

Parameters:

data (DataFrame): The dataset. [required]
column_metadata (ColumnMetaData): The column metadata. [required]

Returns:

DataFrame: The dataset with missing values in the appropriate column replaced with imputed ones.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"\n    Impute missingness in the data via the `impute` strategy. 'Special' values trigger specific behaviour.\n\n    Args:\n        data: The dataset.\n        column_metadata: The column metadata.\n\n    Returns:\n        The dataset with missing values in the appropriate column replaced with imputed ones.\n    \"\"\"\n    if (self.impute == \"mean\" or self.impute == \"median\") and column_metadata.categorical:\n        warnings.warn(\"Cannot impute mean or median for categorical data, using mode instead.\")\n        self.imputation_value = data[column_metadata.name].mode()[0]\n    elif self.impute == \"mean\":\n        self.imputation_value = data[column_metadata.name].mean()\n    elif self.impute == \"median\":\n        self.imputation_value = data[column_metadata.name].median()\n    elif self.impute == \"mode\":\n        self.imputation_value = data[column_metadata.name].mode()[0]\n    else:\n        self.imputation_value = self.impute\n    self.imputation_value = column_metadata.dtype.type(self.imputation_value)\n    try:\n        data[column_metadata.name].fillna(self.imputation_value, inplace=True)\n    except AssertionError:\n        raise ValueError(f\"Could not impute '{self.imputation_value}' into column: '{column_metadata.name}'.\")\n    return data\n
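A pandas-only sketch of the imputation logic on toy data: mean for a numeric column, with mode as the categorical fallback:

```python
import pandas as pd

df = pd.DataFrame({"age": [30.0, None, 50.0], "sex": ["F", None, "F"]})

df["age"] = df["age"].fillna(df["age"].mean())     # mean of [30, 50] -> 40.0
df["sex"] = df["sex"].fillna(df["sex"].mode()[0])  # mode -> "F"
print(df)
```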
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.NullMissingnessStrategy","title":"NullMissingnessStrategy","text":"

Bases: GenericMissingnessStrategy

Null missingness strategy.

Source code in src/nhssynth/modules/dataloader/missingness.py
class NullMissingnessStrategy(GenericMissingnessStrategy):\n    \"\"\"Null missingness strategy.\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__(\"none\")\n\n    def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n        \"\"\"Do nothing.\"\"\"\n        return data\n
"},{"location":"reference/modules/dataloader/missingness/#nhssynth.modules.dataloader.missingness.NullMissingnessStrategy.remove","title":"remove(data, column_metadata)","text":"

Do nothing.

Source code in src/nhssynth/modules/dataloader/missingness.py
def remove(self, data: pd.DataFrame, column_metadata: ColumnMetaData) -> pd.DataFrame:\n    \"\"\"Do nothing.\"\"\"\n    return data\n
"},{"location":"reference/modules/dataloader/run/","title":"run","text":""},{"location":"reference/modules/dataloader/transformers/","title":"transformers","text":""},{"location":"reference/modules/dataloader/transformers/base/","title":"base","text":""},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer","title":"ColumnTransformer","text":"

Bases: ABC

A generic column transformer class to prototype all of the transformers applied via the MetaTransformer.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
class ColumnTransformer(ABC):\n    \"\"\"A generic column transformer class to prototype all of the transformers applied via the [`MetaTransformer`][nhssynth.modules.dataloader.metatransformer.MetaTransformer].\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    @abstractmethod\n    def apply(self, data: pd.DataFrame, missingness_column: Optional[pd.Series]) -> None:\n        \"\"\"Apply the transformer to the data.\"\"\"\n        pass\n\n    @abstractmethod\n    def revert(self, data: pd.DataFrame) -> None:\n        \"\"\"Revert data to pre-transformer state.\"\"\"\n        pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer.apply","title":"apply(data, missingness_column) abstractmethod","text":"

Apply the transformer to the data.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
@abstractmethod\ndef apply(self, data: pd.DataFrame, missingness_column: Optional[pd.Series]) -> None:\n    \"\"\"Apply the transformer to the data.\"\"\"\n    pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.ColumnTransformer.revert","title":"revert(data) abstractmethod","text":"

Revert data to pre-transformer state.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
@abstractmethod\ndef revert(self, data: pd.DataFrame) -> None:\n    \"\"\"Revert data to pre-transformer state.\"\"\"\n    pass\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper","title":"TransformerWrapper","text":"

Bases: ABC

A class to facilitate nesting of ColumnTransformers.

Parameters:

wrapped_transformer (ColumnTransformer): The ColumnTransformer to wrap. [required]

Source code in src/nhssynth/modules/dataloader/transformers/base.py
class TransformerWrapper(ABC):\n    \"\"\"\n    A class to facilitate nesting of [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer]s.\n\n    Args:\n        wrapped_transformer: The [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer] to wrap.\n    \"\"\"\n\n    def __init__(self, wrapped_transformer: ColumnTransformer) -> None:\n        super().__init__()\n        self._wrapped_transformer: ColumnTransformer = wrapped_transformer\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series], **kwargs) -> pd.DataFrame:\n        \"\"\"Method for applying the wrapped transformer to the data.\"\"\"\n        return self._wrapped_transformer.apply(data, missingness_column, **kwargs)\n\n    def revert(self, data: pd.Series, **kwargs) -> pd.DataFrame:\n        \"\"\"Method for reverting the passed data via the wrapped transformer.\"\"\"\n        return self._wrapped_transformer.revert(data, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper.apply","title":"apply(data, missingness_column, **kwargs)","text":"

Method for applying the wrapped transformer to the data.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series], **kwargs) -> pd.DataFrame:\n    \"\"\"Method for applying the wrapped transformer to the data.\"\"\"\n    return self._wrapped_transformer.apply(data, missingness_column, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/base/#nhssynth.modules.dataloader.transformers.base.TransformerWrapper.revert","title":"revert(data, **kwargs)","text":"

Method for reverting the passed data via the wrapped transformer.

Source code in src/nhssynth/modules/dataloader/transformers/base.py
def revert(self, data: pd.Series, **kwargs) -> pd.DataFrame:\n    \"\"\"Method for reverting the passed data via the wrapped transformer.\"\"\"\n    return self._wrapped_transformer.revert(data, **kwargs)\n
"},{"location":"reference/modules/dataloader/transformers/categorical/","title":"categorical","text":""},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer","title":"OHECategoricalTransformer","text":"

Bases: ColumnTransformer

A transformer to one-hot encode categorical features via sklearn's OneHotEncoder. Essentially wraps the fit_transform and inverse_transform methods of OneHotEncoder to comply with the ColumnTransformer interface.

Parameters:

drop (Optional[Union[list, str]]): str or list of str, to pass to OneHotEncoder's drop parameter. [default: None]

Attributes:

missing_value (Any): The value used to fill missing values in the data.

After applying the transformer, the following attributes will be populated:

Attributes:

original_column_name: The name of the original column.
new_column_names: The names of the columns generated by the transformer.

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
class OHECategoricalTransformer(ColumnTransformer):\n    \"\"\"\n    A transformer to one-hot encode categorical features via sklearn's `OneHotEncoder`.\n    Essentially wraps the `fit_transformer` and `inverse_transform` methods of `OneHotEncoder` to comply with the `ColumnTransformer` interface.\n\n    Args:\n        drop: str or list of str, to pass to `OneHotEncoder`'s `drop` parameter.\n\n    Attributes:\n        missing_value: The value used to fill missing values in the data.\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        original_column_name: The name of the original column.\n        new_column_names: The names of the columns generated by the transformer.\n    \"\"\"\n\n    def __init__(self, drop: Optional[Union[list, str]] = None) -> None:\n        super().__init__()\n        self._drop: Union[list, str] = drop\n        self._transformer: OneHotEncoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=self._drop)\n        self.missing_value: Any = None\n\n    def apply(self, data: pd.Series, missing_value: Optional[Any] = None) -> pd.DataFrame:\n        \"\"\"\n        Apply the transformer to the data via sklearn's `OneHotEncoder`'s `fit_transform` method. Name the new columns via manipulation of the original column name.\n        If `missing_value` is provided, fill missing values with this value before applying the transformer to ensure a new category is added.\n\n        Args:\n            data: The column of data to transform.\n            missing_value: The value learned by the `MetaTransformer` to represent missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n        \"\"\"\n        self.original_column_name = data.name\n        if missing_value:\n            data = data.fillna(missing_value)\n            self.missing_value = missing_value\n        transformed_data = pd.DataFrame(\n            self._transformer.fit_transform(data.values.reshape(-1, 1)),\n            columns=self._transformer.get_feature_names_out(input_features=[data.name]),\n        )\n        self.new_column_names = transformed_data.columns\n        return transformed_data\n\n    def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Revert data to pre-transformer state via sklearn's `OneHotEncoder`'s `inverse_transform` method.\n        If `missing_value` is provided, replace instances of this value in the data with `np.nan` to ensure missing values are represented correctly in the case\n        where `missing_value` was 'modelled' and thus generated.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.\n        \"\"\"\n        data[self.original_column_name] = pd.Series(\n            self._transformer.inverse_transform(data[self.new_column_names].values).flatten(),\n            index=data.index,\n            name=self.original_column_name,\n        )\n        if self.missing_value:\n            data[self.original_column_name] = data[self.original_column_name].replace(self.missing_value, np.nan)\n        return data.drop(self.new_column_names, axis=1)\n
"},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer.apply","title":"apply(data, missing_value=None)","text":"

Apply the transformer to the data via sklearn's OneHotEncoder's fit_transform method. Name the new columns via manipulation of the original column name. If missing_value is provided, fill missing values with this value before applying the transformer to ensure a new category is added.

Parameters:

data (Series): The column of data to transform. [required]
missing_value (Optional[Any]): The value learned by the MetaTransformer to represent missingness; this is only used as part of the AugmentMissingnessStrategy. [default: None]

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
def apply(self, data: pd.Series, missing_value: Optional[Any] = None) -> pd.DataFrame:\n    \"\"\"\n    Apply the transformer to the data via sklearn's `OneHotEncoder`'s `fit_transform` method. Name the new columns via manipulation of the original column name.\n    If `missing_value` is provided, fill missing values with this value before applying the transformer to ensure a new category is added.\n\n    Args:\n        data: The column of data to transform.\n        missing_value: The value learned by the `MetaTransformer` to represent missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n    \"\"\"\n    self.original_column_name = data.name\n    if missing_value:\n        data = data.fillna(missing_value)\n        self.missing_value = missing_value\n    transformed_data = pd.DataFrame(\n        self._transformer.fit_transform(data.values.reshape(-1, 1)),\n        columns=self._transformer.get_feature_names_out(input_features=[data.name]),\n    )\n    self.new_column_names = transformed_data.columns\n    return transformed_data\n
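A sketch of the underlying sklearn calls (assuming scikit-learn >= 1.2 for the `sparse_output` argument), mirroring how the new column names are derived from the original column name:

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

data = pd.Series(["A", "B", "A"], name="ward")  # toy categorical column
enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)

encoded = pd.DataFrame(
    enc.fit_transform(data.values.reshape(-1, 1)),
    columns=enc.get_feature_names_out(input_features=[data.name]),
)
print(encoded.columns.tolist())  # ['ward_A', 'ward_B']
```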
"},{"location":"reference/modules/dataloader/transformers/categorical/#nhssynth.modules.dataloader.transformers.categorical.OHECategoricalTransformer.revert","title":"revert(data)","text":"

Revert data to pre-transformer state via sklearn's OneHotEncoder's inverse_transform method. If missing_value is provided, replace instances of this value in the data with np.nan to ensure missing values are represented correctly in the case where missing_value was 'modelled' and thus generated.

Parameters:

data (DataFrame): The full dataset including the column(s) to be reverted to their pre-transformer state. [required]

Returns:

DataFrame: The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.

Source code in src/nhssynth/modules/dataloader/transformers/categorical.py
def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Revert data to pre-transformer state via sklearn's `OneHotEncoder`'s `inverse_transform` method.\n    If `missing_value` is provided, replace instances of this value in the data with `np.nan` to ensure missing values are represented correctly in the case\n    where `missing_value` was 'modelled' and thus generated.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The dataset with a single categorical column that is analogous to the original column, with the same name, and without the generated one-hot columns.\n    \"\"\"\n    data[self.original_column_name] = pd.Series(\n        self._transformer.inverse_transform(data[self.new_column_names].values).flatten(),\n        index=data.index,\n        name=self.original_column_name,\n    )\n    if self.missing_value:\n        data[self.original_column_name] = data[self.original_column_name].replace(self.missing_value, np.nan)\n    return data.drop(self.new_column_names, axis=1)\n
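The reverse direction, shown standalone with the same toy column: `inverse_transform` collapses the one-hot block back into a single categorical column:

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

original = pd.Series(["A", "B", "A"], name="ward")
enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
onehot = enc.fit_transform(original.values.reshape(-1, 1))

recovered = pd.Series(enc.inverse_transform(onehot).flatten(), name="ward")
assert list(recovered) == list(original)
```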
"},{"location":"reference/modules/dataloader/transformers/continuous/","title":"continuous","text":""},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer","title":"ClusterContinuousTransformer","text":"

Bases: ColumnTransformer

A transformer to cluster continuous features via sklearn's BayesianGaussianMixture. Essentially wraps the process of fitting the BGM model and generating cluster assignments and normalised values for the data to comply with the ColumnTransformer interface.

Parameters:

n_components (int): The number of components to use in the BGM model. [default: 10]
n_init (int): The number of initialisations to use in the BGM model. [default: 1]
init_params (str): The initialisation method to use in the BGM model. [default: 'kmeans']
random_state (int): The random state to use in the BGM model. [default: 0]
max_iter (int): The maximum number of iterations to use in the BGM model. [default: 1000]
remove_unused_components (bool): Whether to remove components that have no data assigned (EXPERIMENTAL). [default: False]
clip_output (bool): Whether to clip the output normalised values to the range [-1, 1]. [default: False]

After applying the transformer, the following attributes will be populated:

Attributes:

means: The means of the components in the BGM model.
stds: The standard deviations of the components in the BGM model.
new_column_names: The names of the columns generated by the transformer (one for the normalised values and one for each cluster component).

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
class ClusterContinuousTransformer(ColumnTransformer):\n    \"\"\"\n    A transformer to cluster continuous features via sklearn's `BayesianGaussianMixture`.\n    Essentially wraps the process of fitting the BGM model and generating cluster assignments and normalised values for the data to comply with the `ColumnTransformer` interface.\n\n    Args:\n        n_components: The number of components to use in the BGM model.\n        n_init: The number of initialisations to use in the BGM model.\n        init_params: The initialisation method to use in the BGM model.\n        random_state: The random state to use in the BGM model.\n        max_iter: The maximum number of iterations to use in the BGM model.\n        remove_unused_components: Whether to remove components that have no data assigned EXPERIMENTAL.\n        clip_output: Whether to clip the output normalised values to the range [-1, 1].\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        means: The means of the components in the BGM model.\n        stds: The standard deviations of the components in the BGM model.\n        new_column_names: The names of the columns generated by the transformer (one for the normalised values and one for each cluster component).\n    \"\"\"\n\n    def __init__(\n        self,\n        n_components: int = 10,\n        n_init: int = 1,\n        init_params: str = \"kmeans\",\n        random_state: int = 0,\n        max_iter: int = 1000,\n        remove_unused_components: bool = False,\n        clip_output: bool = False,\n    ) -> None:\n        super().__init__()\n        self._transformer = BayesianGaussianMixture(\n            n_components=n_components,\n            random_state=random_state,\n            n_init=n_init,\n            init_params=init_params,\n            max_iter=max_iter,\n            weight_concentration_prior=1e-3,\n        )\n        self._n_components = n_components\n        self._std_multiplier = 4\n        self._missingness_column_name = None\n        self._max_iter = max_iter\n        self.remove_unused_components = remove_unused_components\n        self.clip_output = clip_output\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None) -> pd.DataFrame:\n        \"\"\"\n        Apply the transformer to the data via sklearn's `BayesianGaussianMixture`'s `fit` and `predict_proba` methods.\n        Name the new columns via the original column name.\n\n        If `missingness_column` is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0\n        (i.e. all values in the normalised column are 0.0). 
We do this by taking the full index before subsetting to non-missing data, then reindexing.\n\n        Args:\n            data: The column of data to transform.\n            missingness_column: The column of data representing missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n\n        Returns:\n            The transformed data (will be multiple columns if `n_components` > 1 at initialisation).\n        \"\"\"\n        self.original_column_name = data.name\n        if missingness_column is not None:\n            self._missingness_column_name = missingness_column.name\n            full_index = data.index\n            data = data[missingness_column == 0]\n        index = data.index\n        data = np.array(data.values.reshape(-1, 1), dtype=data.dtype.name.lower())\n\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n            self._transformer.fit(data)\n\n        self.means = self._transformer.means_.reshape(-1)\n        self.stds = np.sqrt(self._transformer.covariances_).reshape(-1)\n\n        components = np.argmax(self._transformer.predict_proba(data), axis=1)\n        normalised_values = (data - self.means.reshape(1, -1)) / (self._std_multiplier * self.stds.reshape(1, -1))\n        normalised = normalised_values[np.arange(len(data)), components]\n        normalised = np.clip(normalised, -1.0, 1.0)\n        components = np.eye(self._n_components, dtype=int)[components]\n\n        transformed_data = pd.DataFrame(\n            np.hstack([normalised.reshape(-1, 1), components]),\n            index=index,\n            columns=[f\"{self.original_column_name}_normalised\"]\n            + [f\"{self.original_column_name}_c{i + 1}\" for i in range(self._n_components)],\n        )\n\n        # EXPERIMENTAL feature, removing components from the column matrix that have no data assigned to them\n        if self.remove_unused_components:\n            nunique = transformed_data.iloc[:, 1:].nunique(dropna=False)\n            unused_components = nunique[nunique == 1].index\n            unused_component_idx = [transformed_data.columns.get_loc(col_name) - 1 for col_name in unused_components]\n            self.means = np.delete(self.means, unused_component_idx)\n            self.stds = np.delete(self.stds, unused_component_idx)\n            transformed_data.drop(unused_components, axis=1, inplace=True)\n\n        if missingness_column is not None:\n            transformed_data = pd.concat([transformed_data.reindex(full_index).fillna(0.0), missingness_column], axis=1)\n\n        self.new_column_names = transformed_data.columns\n        return transformed_data.astype(\n            {col_name: int for col_name in transformed_data.columns if re.search(r\"_c\\d+\", col_name)}\n        )\n\n    def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n        \"\"\"\n        Revert data to pre-transformer state via the means and stds of the BGM. 
Extract the relevant columns from the data via the `new_column_names` attribute.\n        If `missingness_column` was provided to the `apply` method, drop the missing values from the data before reverting and use the `full_index` to\n        reintroduce missing values when `original_column_name` is constructed.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.\n        \"\"\"\n        working_data = data[self.new_column_names]\n        full_index = working_data.index\n        if self._missingness_column_name is not None:\n            working_data = working_data[working_data[self._missingness_column_name] == 0]\n            working_data = working_data.drop(self._missingness_column_name, axis=1)\n        index = working_data.index\n\n        components = np.argmax(working_data.filter(regex=r\".*_c\\d+\").values, axis=1)\n        working_data = working_data.filter(like=\"_normalised\").values.reshape(-1)\n        if self.clip_output:\n            working_data = np.clip(working_data, -1.0, 1.0)\n\n        mean_t = self.means[components]\n        std_t = self.stds[components]\n        data[self.original_column_name] = pd.Series(\n            working_data * self._std_multiplier * std_t + mean_t, index=index, name=self.original_column_name\n        ).reindex(full_index)\n        data.drop(self.new_column_names, axis=1, inplace=True)\n        return data\n
"},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer.apply","title":"apply(data, missingness_column=None)","text":"

Apply the transformer to the data via sklearn's BayesianGaussianMixture's fit and predict_proba methods. Name the new columns via the original column name.

If missingness_column is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0 (i.e. all values in the normalised column are 0.0). We do this by taking the full index before subsetting to non-missing data, then reindexing.

Parameters:

data (Series): The column of data to transform. [required]
missingness_column (Optional[Series]): The column of data representing missingness; this is only used as part of the AugmentMissingnessStrategy. [default: None]

Returns:

DataFrame: The transformed data (will be multiple columns if n_components > 1 at initialisation).

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None) -> pd.DataFrame:\n    \"\"\"\n    Apply the transformer to the data via sklearn's `BayesianGaussianMixture`'s `fit` and `predict_proba` methods.\n    Name the new columns via the original column name.\n\n    If `missingness_column` is provided, use this to extract the non-missing data; the missing values are assigned to a new pseudo-cluster with mean 0\n    (i.e. all values in the normalised column are 0.0). We do this by taking the full index before subsetting to non-missing data, then reindexing.\n\n    Args:\n        data: The column of data to transform.\n        missingness_column: The column of data representing missingness, this is only used as part of the `AugmentMissingnessStrategy`.\n\n    Returns:\n        The transformed data (will be multiple columns if `n_components` > 1 at initialisation).\n    \"\"\"\n    self.original_column_name = data.name\n    if missingness_column is not None:\n        self._missingness_column_name = missingness_column.name\n        full_index = data.index\n        data = data[missingness_column == 0]\n    index = data.index\n    data = np.array(data.values.reshape(-1, 1), dtype=data.dtype.name.lower())\n\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n        self._transformer.fit(data)\n\n    self.means = self._transformer.means_.reshape(-1)\n    self.stds = np.sqrt(self._transformer.covariances_).reshape(-1)\n\n    components = np.argmax(self._transformer.predict_proba(data), axis=1)\n    normalised_values = (data - self.means.reshape(1, -1)) / (self._std_multiplier * self.stds.reshape(1, -1))\n    normalised = normalised_values[np.arange(len(data)), components]\n    normalised = np.clip(normalised, -1.0, 1.0)\n    components = np.eye(self._n_components, dtype=int)[components]\n\n    transformed_data = pd.DataFrame(\n        np.hstack([normalised.reshape(-1, 1), components]),\n        index=index,\n        columns=[f\"{self.original_column_name}_normalised\"]\n        + [f\"{self.original_column_name}_c{i + 1}\" for i in range(self._n_components)],\n    )\n\n    # EXPERIMENTAL feature, removing components from the column matrix that have no data assigned to them\n    if self.remove_unused_components:\n        nunique = transformed_data.iloc[:, 1:].nunique(dropna=False)\n        unused_components = nunique[nunique == 1].index\n        unused_component_idx = [transformed_data.columns.get_loc(col_name) - 1 for col_name in unused_components]\n        self.means = np.delete(self.means, unused_component_idx)\n        self.stds = np.delete(self.stds, unused_component_idx)\n        transformed_data.drop(unused_components, axis=1, inplace=True)\n\n    if missingness_column is not None:\n        transformed_data = pd.concat([transformed_data.reindex(full_index).fillna(0.0), missingness_column], axis=1)\n\n    self.new_column_names = transformed_data.columns\n    return transformed_data.astype(\n        {col_name: int for col_name in transformed_data.columns if re.search(r\"_c\\d+\", col_name)}\n    )\n
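A compressed sketch of the forward pass with plain scikit-learn and numpy (toy one-dimensional data; fewer components than the default of 10):

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, 200), rng.normal(10, 1, 200)]).reshape(-1, 1)

bgm = BayesianGaussianMixture(n_components=3, random_state=0, max_iter=1000)
bgm.fit(x)

means = bgm.means_.reshape(-1)
stds = np.sqrt(bgm.covariances_).reshape(-1)

# Hard cluster assignment plus a per-component normalised value in [-1, 1].
components = np.argmax(bgm.predict_proba(x), axis=1)
normalised = (x.reshape(-1) - means[components]) / (4 * stds[components])
normalised = np.clip(normalised, -1.0, 1.0)
one_hot = np.eye(3, dtype=int)[components]  # the cluster-assignment columns
```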
"},{"location":"reference/modules/dataloader/transformers/continuous/#nhssynth.modules.dataloader.transformers.continuous.ClusterContinuousTransformer.revert","title":"revert(data)","text":"

Revert data to pre-transformer state via the means and stds of the BGM. Extract the relevant columns from the data via the new_column_names attribute. If missingness_column was provided to the apply method, drop the missing values from the data before reverting and use the full_index to reintroduce missing values when original_column_name is constructed.

Parameters:

data (DataFrame): The full dataset including the column(s) to be reverted to their pre-transformer state. [required]

Returns:

DataFrame: The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.

Source code in src/nhssynth/modules/dataloader/transformers/continuous.py
def revert(self, data: pd.DataFrame) -> pd.DataFrame:\n    \"\"\"\n    Revert data to pre-transformer state via the means and stds of the BGM. Extract the relevant columns from the data via the `new_column_names` attribute.\n    If `missingness_column` was provided to the `apply` method, drop the missing values from the data before reverting and use the `full_index` to\n    reintroduce missing values when `original_column_name` is constructed.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The dataset with a single continuous column that is analogous to the original column, with the same name, and without the generated columns from which it is derived.\n    \"\"\"\n    working_data = data[self.new_column_names]\n    full_index = working_data.index\n    if self._missingness_column_name is not None:\n        working_data = working_data[working_data[self._missingness_column_name] == 0]\n        working_data = working_data.drop(self._missingness_column_name, axis=1)\n    index = working_data.index\n\n    components = np.argmax(working_data.filter(regex=r\".*_c\\d+\").values, axis=1)\n    working_data = working_data.filter(like=\"_normalised\").values.reshape(-1)\n    if self.clip_output:\n        working_data = np.clip(working_data, -1.0, 1.0)\n\n    mean_t = self.means[components]\n    std_t = self.stds[components]\n    data[self.original_column_name] = pd.Series(\n        working_data * self._std_multiplier * std_t + mean_t, index=index, name=self.original_column_name\n    ).reindex(full_index)\n    data.drop(self.new_column_names, axis=1, inplace=True)\n    return data\n
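Reverting is then the inverse affine map per assigned component. Continuing the names from the sketch above (`normalised`, `components`, `means`, `stds`, `x`):

```python
# Map each normalised value back through its assigned component's mean and std.
reconstructed = normalised * 4 * stds[components] + means[components]
print(np.abs(reconstructed - x.reshape(-1)).max())  # ~0 unless a value was clipped
```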
"},{"location":"reference/modules/dataloader/transformers/datetime/","title":"datetime","text":""},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer","title":"DatetimeTransformer","text":"

Bases: TransformerWrapper

A transformer to convert datetime features to numeric features before applying an underlying (wrapped) transformer. The datetime features are converted to nanoseconds since the epoch, and missing values are assigned to 0.0 under the AugmentMissingnessStrategy.

Parameters:

transformer (ColumnTransformer): The ColumnTransformer to wrap. [required]

After applying the transformer, the following attributes will be populated:

Attributes:

original_column_name: The name of the original column.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
class DatetimeTransformer(TransformerWrapper):\n    \"\"\"\n    A transformer to convert datetime features to numeric features. Before applying an underlying (wrapped) transformer.\n    The datetime features are converted to nanoseconds since the epoch, and missing values are assigned to 0.0 under the `AugmentMissingnessStrategy`.\n\n    Args:\n        transformer: The [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer] to wrap.\n\n    After applying the transformer, the following attributes will be populated:\n\n    Attributes:\n        original_column_name: The name of the original column.\n    \"\"\"\n\n    def __init__(self, transformer: ColumnTransformer) -> None:\n        super().__init__(transformer)\n\n    def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None, **kwargs) -> pd.DataFrame:\n        \"\"\"\n        Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch.\n        The float value of `pd.NaT` under the operation above is then replaced with `np.nan` to ensure missing values are represented correctly.\n        Finally, the wrapped transformer is applied to the data.\n\n        Args:\n            data: The column of data to transform.\n            missingness_column: The column of missingness indicators to augment the data with.\n\n        Returns:\n            The transformed data.\n        \"\"\"\n        self.original_column_name = data.name\n        floored_data = pd.Series(data.dt.floor(\"ns\").to_numpy().astype(float), name=data.name)\n        nan_corrected_data = floored_data.replace(pd.to_datetime(pd.NaT).to_numpy().astype(float), np.nan)\n        return super().apply(nan_corrected_data, missingness_column, **kwargs)\n\n    def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:\n        \"\"\"\n        The wrapped transformer's `revert` method is applied to the data. The data is then converted back to datetime format.\n\n        Args:\n            data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n        Returns:\n            The reverted data.\n        \"\"\"\n        reverted_data = super().revert(data, **kwargs)\n        data[self.original_column_name] = pd.to_datetime(\n            reverted_data[self.original_column_name].astype(\"Int64\"), unit=\"ns\"\n        )\n        return data\n
"},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer.apply","title":"apply(data, missingness_column=None, **kwargs)","text":"

Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch. The float value of pd.NaT under the operation above is then replaced with np.nan to ensure missing values are represented correctly. Finally, the wrapped transformer is applied to the data.

Parameters:

Name Type Description Default data Series

The column of data to transform.

required missingness_column Optional[Series]

The column of missingness indicators to augment the data with.

None

Returns:

Type Description DataFrame

The transformed data.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def apply(self, data: pd.Series, missingness_column: Optional[pd.Series] = None, **kwargs) -> pd.DataFrame:\n    \"\"\"\n    Firstly, the datetime data is floored to the nano-second level. Next, the floored data is converted to float nanoseconds since the epoch.\n    The float value of `pd.NaT` under the operation above is then replaced with `np.nan` to ensure missing values are represented correctly.\n    Finally, the wrapped transformer is applied to the data.\n\n    Args:\n        data: The column of data to transform.\n        missingness_column: The column of missingness indicators to augment the data with.\n\n    Returns:\n        The transformed data.\n    \"\"\"\n    self.original_column_name = data.name\n    floored_data = pd.Series(data.dt.floor(\"ns\").to_numpy().astype(float), name=data.name)\n    nan_corrected_data = floored_data.replace(pd.to_datetime(pd.NaT).to_numpy().astype(float), np.nan)\n    return super().apply(nan_corrected_data, missingness_column, **kwargs)\n
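The conversion performed by apply can be illustrated outside the transformer. This is a minimal sketch using only pandas and numpy; the column name and values are hypothetical.

import numpy as np
import pandas as pd

# A standalone sketch of the conversion performed by `apply` (column name is hypothetical).
dates = pd.Series(pd.to_datetime(["2024-01-01", None, "2024-06-30 12:00"]), name="admission_date")

# Floor to nanosecond precision and view as float nanoseconds since the epoch.
floored = pd.Series(dates.dt.floor("ns").to_numpy().astype(float), name=dates.name)

# Under the cast above, pd.NaT becomes a sentinel float; map it back to np.nan.
nat_as_float = pd.to_datetime(pd.NaT).to_numpy().astype(float)
numeric = floored.replace(nat_as_float, np.nan)

# `revert` mirrors this: cast back to nullable Int64 nanoseconds and re-interpret as datetimes.
recovered = pd.to_datetime(numeric.astype("Int64"), unit="ns")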
"},{"location":"reference/modules/dataloader/transformers/datetime/#nhssynth.modules.dataloader.transformers.datetime.DatetimeTransformer.revert","title":"revert(data, **kwargs)","text":"

The wrapped transformer's revert method is applied to the data. The data is then converted back to datetime format.

Parameters:

Name Type Description Default data DataFrame

The full dataset including the column(s) to be reverted to their pre-transformer state.

required

Returns:

Type Description DataFrame

The reverted data.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:\n    \"\"\"\n    The wrapped transformer's `revert` method is applied to the data. The data is then converted back to datetime format.\n\n    Args:\n        data: The full dataset including the column(s) to be reverted to their pre-transformer state.\n\n    Returns:\n        The reverted data.\n    \"\"\"\n    reverted_data = super().revert(data, **kwargs)\n    data[self.original_column_name] = pd.to_datetime(\n        reverted_data[self.original_column_name].astype(\"Int64\"), unit=\"ns\"\n    )\n    return data\n
"},{"location":"reference/modules/evaluation/","title":"evaluation","text":""},{"location":"reference/modules/evaluation/aequitas/","title":"aequitas","text":""},{"location":"reference/modules/evaluation/io/","title":"io","text":""},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata, dir_experiment)","text":"

Sets up the input paths for the files required by the evaluation module.

Parameters:

Name Type Description Default fn_dataset str

The base name of the dataset.

required fn_typed str

The name of the typed real dataset file.

required fn_synthetic_datasets str

The filename of the collection of synthetic datasets.

required fn_sdv_metadata str

The name of the SDV metadata file.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, str]

The base name of the dataset and the paths to the typed data, synthetic datasets and SDV metadata files.

Source code in src/nhssynth/modules/evaluation/io.py
def check_input_paths(\n    fn_dataset: str, fn_typed: str, fn_synthetic_datasets: str, fn_sdv_metadata: str, dir_experiment: Path\n) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_typed: The name of the typed real dataset file.\n        fn_synthetic_datasets: The filename of the collection of synethtic datasets.\n        fn_sdv_metadata: The name of the SDV metadata file.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_typed, fn_synthetic_datasets, fn_sdv_metadata = io.consistent_endings(\n        [fn_typed, fn_synthetic_datasets, fn_sdv_metadata]\n    )\n    fn_typed, fn_synthetic_datasets, fn_sdv_metadata = io.potential_suffixes(\n        [fn_typed, fn_synthetic_datasets, fn_sdv_metadata], fn_dataset\n    )\n    io.warn_if_path_supplied([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)\n    io.check_exists([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)\n    return fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata\n
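For illustration, a call might look as follows; the filenames and experiment directory are hypothetical, and the helper asserts that the resolved files exist under dir_experiment.

from pathlib import Path

# Hypothetical names, purely to show the call shape; check_input_paths normalises
# endings/suffixes and checks the files exist in the experiment directory.
fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata = check_input_paths(
    fn_dataset="support.csv",
    fn_typed="support_typed",
    fn_synthetic_datasets="support_synthetic_datasets",
    fn_sdv_metadata="support_sdv_metadata",
    dir_experiment=Path("experiments") / "2024-02-28",
)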
"},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args, or from disk when the dataloader has not been run previously.

Parameters:

Name Type Description Default args Namespace

The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, DataFrame, DataFrame, dict[str, dict[str, Any]]]

The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata.

Source code in src/nhssynth/modules/evaluation/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, pd.DataFrame, dict[str, dict[str, Any]]]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"typed\", \"synthetic_datasets\", \"sdv_metadata\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"typed\"],\n            args.module_handover[\"synthetic_datasets\"],\n            args.module_handover[\"sdv_metadata\"],\n        )\n    else:\n        fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata = check_input_paths(\n            args.dataset, args.typed, args.synthetic_datasets, args.sdv_metadata, dir_experiment\n        )\n        with open(dir_experiment / fn_typed, \"rb\") as f:\n            real_data = pickle.load(f).contents\n        with open(dir_experiment / fn_sdv_metadata, \"rb\") as f:\n            sdv_metadata = pickle.load(f)\n        with open(dir_experiment / fn_synthetic_datasets, \"rb\") as f:\n            synthetic_datasets = pickle.load(f).contents\n\n        return fn_dataset, real_data, synthetic_datasets, sdv_metadata\n
"},{"location":"reference/modules/evaluation/io/#nhssynth.modules.evaluation.io.output_eval","title":"output_eval(evaluations, fn_dataset, fn_evaluations, dir_experiment)","text":"

Sets up the output path for the evaluations file and writes the evaluations to disk.

Parameters:

Name Type Description Default evaluations DataFrame

The evaluations to output.

required fn_dataset Path

The base name of the dataset.

required fn_evaluations str

The filename of the collection of evaluations.

required dir_experiment Path

The path to the experiment output directory.

required

Returns:

Type Description None

None; the evaluations are pickled to a file in the experiment directory.

Source code in src/nhssynth/modules/evaluation/io.py
def output_eval(\n    evaluations: pd.DataFrame,\n    fn_dataset: Path,\n    fn_evaluations: str,\n    dir_experiment: Path,\n) -> None:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        evaluations: The evaluations to output.\n        fn_dataset: The base name of the dataset.\n        fn_evaluations: The filename of the collection of evaluations.\n        dir_experiment: The path to the experiment output directory.\n\n    Returns:\n        The path to output the model.\n    \"\"\"\n    fn_evaluations = io.consistent_ending(fn_evaluations)\n    fn_evaluations = io.potential_suffix(fn_evaluations, fn_dataset)\n    io.warn_if_path_supplied([fn_evaluations], dir_experiment)\n    with open(dir_experiment / fn_evaluations, \"wb\") as f:\n        pickle.dump(Evaluations(evaluations), f)\n
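For illustration, a call might look as follows; the names are hypothetical. Nothing is returned; the pickled Evaluations bundle is simply written into the experiment directory.

from pathlib import Path

# Hypothetical names, purely to show the call shape.
output_eval(
    evaluations,                    # the dataframe of evaluations to persist
    fn_dataset="support",
    fn_evaluations="evaluations",
    dir_experiment=Path("experiments") / "2024-02-28",
)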
"},{"location":"reference/modules/evaluation/metrics/","title":"metrics","text":""},{"location":"reference/modules/evaluation/run/","title":"run","text":""},{"location":"reference/modules/evaluation/tasks/","title":"tasks","text":""},{"location":"reference/modules/evaluation/tasks/#nhssynth.modules.evaluation.tasks.Task","title":"Task","text":"

A task offers a light-touch way for users to specify any arbitrary downstream task that they want to run on a dataset.

Parameters:

Name Type Description Default name str

The name of the task.

required run Callable

The function to run.

required supports_aequitas

Whether the task supports Aequitas evaluation.

False description str

The description of the task.

'' Source code in src/nhssynth/modules/evaluation/tasks.py
class Task:\n    \"\"\"\n    A task offers a light-touch way for users to specify any arbitrary downstream task that they want to run on a dataset.\n\n    Args:\n        name: The name of the task.\n        run: The function to run.\n        supports_aequitas: Whether the task supports Aequitas evaluation.\n        description: The description of the task.\n    \"\"\"\n\n    def __init__(self, name: str, run: Callable, supports_aequitas=False, description: str = \"\"):\n        self._name: str = name\n        self._run: Callable = run\n        self._supports_aequitas: bool = supports_aequitas\n        self._description: str = description\n\n    def __str__(self) -> str:\n        return f\"{self.name}: {self.description}\" if self.description else self.name\n\n    def __repr__(self) -> str:\n        return str([self.name, self.run, self.supports_aequitas, self.description])\n\n    def run(self, *args, **kwargs):\n        return self._run(*args, **kwargs)\n
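A minimal sketch of a downstream task file is shown below. The target column, model choice and metric name are hypothetical, and the import path for Task is assumed from this reference; the evaluation module expects the run callable to return a prediction column and a dict of metric values.

# Hypothetical contents of a downstream task file, e.g. <tasks_root>/<dataset>/predict_outcome.py.
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from nhssynth.modules.evaluation.tasks import Task


def run_task(data: pd.DataFrame) -> tuple[pd.Series, dict]:
    # Assumes a binary "outcome" column and numeric features (both hypothetical).
    X, y = data.drop(columns=["outcome"]), data["outcome"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    preds = pd.Series(model.predict(X), index=data.index, name="outcome_pred")
    return preds, {"outcome_auc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])}


# get_tasks imports this module and picks up the module-level `task` attribute.
task = Task(
    name="predict_outcome",
    run=run_task,
    supports_aequitas=True,
    description="Predict the (hypothetical) outcome column",
)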
"},{"location":"reference/modules/evaluation/tasks/#nhssynth.modules.evaluation.tasks.get_tasks","title":"get_tasks(fn_dataset, tasks_root)","text":"

Searches for and imports all tasks in the tasks directory for a given dataset. Uses importlib to extract the task from the file.

Parameters:

Name Type Description Default fn_dataset str

The name of the dataset.

required tasks_root str

The root directory for downstream tasks.

required

Returns:

Type Description list[Task]

A list of tasks.

Source code in src/nhssynth/modules/evaluation/tasks.py
def get_tasks(\n    fn_dataset: str,\n    tasks_root: str,\n) -> list[Task]:\n    \"\"\"\n    Searches for and imports all tasks in the tasks directory for a given dataset.\n    Uses `importlib` to extract the task from the file.\n\n    Args:\n        fn_dataset: The name of the dataset.\n        tasks_root: The root directory for downstream tasks.\n\n    Returns:\n        A list of tasks.\n    \"\"\"\n    tasks_dir = Path(tasks_root) / fn_dataset\n    assert (\n        tasks_dir.exists()\n    ), f\"Downstream tasks directory does not exist ({tasks_dir}), NB there should be a directory in TASKS_DIR with the same name as the dataset.\"\n    tasks = []\n    for task_path in tasks_dir.iterdir():\n        if task_path.name.startswith((\".\", \"__\")):\n            continue\n        assert task_path.suffix == \".py\", f\"Downstream task file must be a python file ({task_path.name})\"\n        spec = importlib.util.spec_from_file_location(\n            \"nhssynth_task_\" + task_path.name, os.getcwd() + \"/\" + str(task_path)\n        )\n        task_module = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(task_module)\n        tasks.append(task_module.task)\n    return tasks\n
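For example, with a hypothetical layout like the following, a single call discovers every task file for the dataset:

# tasks/                     <- tasks_root (hypothetical)
#   support/                 <- directory named after the dataset
#     predict_outcome.py     <- defines a module-level `task` (see the sketch above)
tasks = get_tasks(fn_dataset="support", tasks_root="tasks")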
"},{"location":"reference/modules/evaluation/utils/","title":"utils","text":""},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame","title":"EvalFrame","text":"

Data structure for specifying and recording the evaluations of a set of synthetic datasets against a real dataset. All of the choices made by the user in the evaluation module are consolidated into this class.

After running evaluate on a set of synthetic datasets, the evaluations can be retrieved using get_evaluations. They are stored in a dict of dataframes with indices matching that of the supplied dataframe of synthetic datasets.

Parameters:

Name Type Description Default tasks list[Task]

A list of downstream tasks to run on the experiments.

required metrics list[str]

A list of metrics to calculate on the experiments.

required sdv_metadata dict[str, dict[str, str]]

The SDV metadata for the dataset.

required aequitas bool

Whether to run Aequitas on the results of supported downstream tasks.

False aequitas_attributes list[str]

The fairness-related attributes to use for Aequitas analysis.

[] key_numerical_fields list[str]

The numerical fields to use for SDV privacy metrics.

[] sensitive_numerical_fields list[str]

The numerical fields to use for SDV privacy metrics.

[] key_categorical_fields list[str]

The categorical fields to use for SDV privacy metrics.

[] sensitive_categorical_fields list[str]

The categorical fields to use for SDV privacy metrics.

[] Source code in src/nhssynth/modules/evaluation/utils.py
class EvalFrame:\n    \"\"\"\n    Data structure for specifying and recording the evaluations of a set of synthetic datasets against a real dataset.\n    All of the choices made by the user in the evaluation module are consolidated into this class.\n\n    After running `evaluate` on a set of synthetic datasets, the evaluations can be retrieved using `get_evaluations`.\n    They are stored in a dict of dataframes with indices matching that of the supplied dataframe of synthetic datasets.\n\n    Args:\n        tasks: A list of downstream tasks to run on the experiments.\n        metrics: A list of metrics to calculate on the experiments.\n        sdv_metadata: The SDV metadata for the dataset.\n        aequitas: Whether to run Aequitas on the results of supported downstream tasks.\n        aequitas_attributes: The fairness-related attributes to use for Aequitas analysis.\n        key_numerical_fields: The numerical fields to use for SDV privacy metrics.\n        sensitive_numerical_fields: The numerical fields to use for SDV privacy metrics.\n        key_categorical_fields: The categorical fields to use for SDV privacy metrics.\n        sensitive_categorical_fields: The categorical fields to use for SDV privacy metrics.\n    \"\"\"\n\n    def __init__(\n        self,\n        tasks: list[Task],\n        metrics: list[str],\n        sdv_metadata: dict[str, dict[str, str]],\n        aequitas: bool = False,\n        aequitas_attributes: list[str] = [],\n        key_numerical_fields: list[str] = [],\n        sensitive_numerical_fields: list[str] = [],\n        key_categorical_fields: list[str] = [],\n        sensitive_categorical_fields: list[str] = [],\n    ):\n        self._tasks = tasks\n        self._aequitas = aequitas\n        self._aequitas_attributes = aequitas_attributes\n\n        self._metrics = metrics\n        self._sdv_metadata = sdv_metadata\n\n        self._key_numerical_fields = key_numerical_fields\n        self._sensitive_numerical_fields = sensitive_numerical_fields\n        self._key_categorical_fields = key_categorical_fields\n        self._sensitive_categorical_fields = sensitive_categorical_fields\n        assert all([metric not in NUMERICAL_PRIVACY_METRICS for metric in self._metrics]) or (\n            self._key_numerical_fields and self._sensitive_numerical_fields\n        ), \"Numerical key and sensitive fields must be provided when an SDV privacy metric is used.\"\n        assert all([metric not in CATEGORICAL_PRIVACY_METRICS for metric in self._metrics]) or (\n            self._key_categorical_fields and self._sensitive_categorical_fields\n        ), \"Categorical key and sensitive fields must be provided when an SDV privacy metric is used.\"\n\n        self._metric_groups = self._build_metric_groups()\n\n    def _build_metric_groups(self) -> list[str]:\n        \"\"\"\n        Iterate through the concatenated list of metrics provided by the user and refer to the\n        [defined metric groups][nhssynth.common.constants] to identify which to evaluate.\n\n        Returns:\n            A list of metric groups to evaluate.\n        \"\"\"\n        metric_groups = set()\n        if self._tasks:\n            metric_groups.add(\"task\")\n        if self._aequitas:\n            metric_groups.add(\"aequitas\")\n        for metric in self._metrics:\n            if metric in TABLE_METRICS:\n                metric_groups.add(\"table\")\n            if metric in NUMERICAL_PRIVACY_METRICS or metric in CATEGORICAL_PRIVACY_METRICS:\n                
metric_groups.add(\"privacy\")\n            if metric in TABLE_METRICS and issubclass(TABLE_METRICS[metric], MultiSingleColumnMetric):\n                metric_groups.add(\"columnwise\")\n            if metric in TABLE_METRICS and issubclass(TABLE_METRICS[metric], MultiColumnPairsMetric):\n                metric_groups.add(\"pairwise\")\n        return list(metric_groups)\n\n    def evaluate(self, real_dataset: pd.DataFrame, synthetic_datasets: list[dict[str, Any]]) -> None:\n        \"\"\"\n        Evaluate a set of synthetic datasets against a real dataset.\n\n        Args:\n            real_dataset: The real dataset to evaluate against.\n            synthetic_datasets: The synthetic datasets to evaluate.\n        \"\"\"\n        assert not any(\"Real\" in i for i in synthetic_datasets.index), \"Real is a reserved dataset ID.\"\n        assert synthetic_datasets.index.is_unique, \"Dataset IDs must be unique.\"\n        self._evaluations = pd.DataFrame(index=synthetic_datasets.index, columns=self._metric_groups)\n        self._evaluations.loc[(\"Real\", None, None)] = self._step(real_dataset)\n        pbar = tqdm(synthetic_datasets.iterrows(), desc=\"Evaluating\", total=len(synthetic_datasets))\n        for i, dataset in pbar:\n            pbar.set_description(f\"Evaluating {i[0]}, repeat {i[1]}, config {i[2]}\")\n            self._evaluations.loc[i] = self._step(real_dataset, dataset.values[0])\n\n    def get_evaluations(self) -> dict[str, pd.DataFrame]:\n        \"\"\"\n        Unpack the `self._evaluations` dataframe, where each metric group is a column, into a dict of dataframes.\n\n        Returns:\n            A dict of dataframes, one for each metric group, containing the evaluations.\n        \"\"\"\n        assert hasattr(\n            self, \"_evaluations\"\n        ), \"You must first run `evaluate` on a `real_dataset` and set of `synthetic_datasets`.\"\n        return {\n            metric_group: pd.DataFrame(\n                self._evaluations[metric_group].values.tolist(), index=self._evaluations.index\n            ).dropna(how=\"all\")\n            for metric_group in self._metric_groups\n        }\n\n    def _task_step(self, data: pd.DataFrame) -> dict[str, dict]:\n        \"\"\"\n        Run the downstream tasks on the dataset. 
Optionally run Aequitas on the results of the tasks.\n\n        Args:\n            data: The dataset to run the tasks on.\n\n        Returns:\n            A dict of dicts, one for each metric group, to be populated with each groups metric values.\n        \"\"\"\n        metric_dict = {metric_group: {} for metric_group in self._metric_groups}\n        for task in tqdm(self._tasks, desc=\"Running downstream tasks\", leave=False):\n            task_pred_column, task_metric_values = task.run(data)\n            metric_dict[\"task\"].update(task_metric_values)\n            if self._aequitas and task.supports_aequitas:\n                metric_dict[\"aequitas\"].update(run_aequitas(data[self._aequitas_attributes].join(task_pred_column)))\n        return metric_dict\n\n    def _compute_metric(\n        self, metric_dict: dict, metric: str, real_data: pd.DataFrame, synthetic_data: pd.DataFrame\n    ) -> dict[str, dict]:\n        \"\"\"\n        Given a metric, determine the correct way to evaluate it via the lists defined in `nhssynth.common.constants`.\n\n        Args:\n            metric_dict: The dict of dicts to populate with metric values.\n            metric: The metric to evaluate.\n            real_data: The real dataset to evaluate against.\n            synthetic_data: The synthetic dataset to evaluate.\n\n        Returns:\n            The metric_dict updated with the value of the metric.\n        \"\"\"\n        with pd.option_context(\"mode.chained_assignment\", None), warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"ConvergenceWarning\")\n            if metric in TABLE_METRICS:\n                metric_dict[\"table\"][metric] = TABLE_METRICS[metric].compute(\n                    real_data, synthetic_data, self._sdv_metadata\n                )\n                if issubclass(TABLE_METRICS[metric], MultiSingleColumnMetric):\n                    metric_dict[\"columnwise\"][metric] = TABLE_METRICS[metric].compute_breakdown(\n                        real_data, synthetic_data, self._sdv_metadata\n                    )\n                elif issubclass(TABLE_METRICS[metric], MultiColumnPairsMetric):\n                    metric_dict[\"pairwise\"][metric] = TABLE_METRICS[metric].compute_breakdown(\n                        real_data, synthetic_data, self._sdv_metadata\n                    )\n            elif metric in NUMERICAL_PRIVACY_METRICS:\n                metric_dict[\"privacy\"][metric] = NUMERICAL_PRIVACY_METRICS[metric].compute(\n                    real_data.dropna(),\n                    synthetic_data.dropna(),\n                    self._sdv_metadata,\n                    self._key_numerical_fields,\n                    self._sensitive_numerical_fields,\n                )\n            elif metric in CATEGORICAL_PRIVACY_METRICS:\n                metric_dict[\"privacy\"][metric] = CATEGORICAL_PRIVACY_METRICS[metric].compute(\n                    real_data.dropna(),\n                    synthetic_data.dropna(),\n                    self._sdv_metadata,\n                    self._key_categorical_fields,\n                    self._sensitive_categorical_fields,\n                )\n        return metric_dict\n\n    def _step(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame = None) -> dict[str, dict]:\n        \"\"\"\n        Run the two functions above (or only the tasks when no synthetic data is provided).\n\n        Args:\n            real_data: The real dataset to evaluate against.\n            synthetic_data: The synthetic dataset to 
evaluate.\n\n        Returns:\n            A dict of dicts, one for each metric grou, to populate a row of `self._evaluations` corresponding to the `synthetic_data`.\n        \"\"\"\n        if synthetic_data is None:\n            metric_dict = self._task_step(real_data)\n        else:\n            metric_dict = self._task_step(synthetic_data)\n            for metric in tqdm(self._metrics, desc=\"Running metrics\", leave=False):\n                metric_dict = self._compute_metric(metric_dict, metric, real_data, synthetic_data)\n        return metric_dict\n
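A hedged sketch of how an EvalFrame might be assembled and run. The metric name is a placeholder, and real_data, synthetic_datasets, sdv_metadata and tasks are assumed to come from the loading helpers and model module described elsewhere in this reference.

# Sketch only: "KSComplement" stands in for any key of TABLE_METRICS, and
# synthetic_datasets is assumed to be indexed by (architecture, repeat, config).
frame = EvalFrame(
    tasks=tasks,                # e.g. from get_tasks(...)
    metrics=["KSComplement"],   # placeholder metric name
    sdv_metadata=sdv_metadata,  # e.g. from load_required_data(...)
    aequitas=False,
)
frame.evaluate(real_data, synthetic_datasets)
evaluations = frame.get_evaluations()  # dict of dataframes keyed by metric group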
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame.evaluate","title":"evaluate(real_dataset, synthetic_datasets)","text":"

Evaluate a set of synthetic datasets against a real dataset.

Parameters:

Name Type Description Default real_dataset DataFrame

The real dataset to evaluate against.

required synthetic_datasets list[dict[str, Any]]

The synthetic datasets to evaluate.

required Source code in src/nhssynth/modules/evaluation/utils.py
def evaluate(self, real_dataset: pd.DataFrame, synthetic_datasets: list[dict[str, Any]]) -> None:\n    \"\"\"\n    Evaluate a set of synthetic datasets against a real dataset.\n\n    Args:\n        real_dataset: The real dataset to evaluate against.\n        synthetic_datasets: The synthetic datasets to evaluate.\n    \"\"\"\n    assert not any(\"Real\" in i for i in synthetic_datasets.index), \"Real is a reserved dataset ID.\"\n    assert synthetic_datasets.index.is_unique, \"Dataset IDs must be unique.\"\n    self._evaluations = pd.DataFrame(index=synthetic_datasets.index, columns=self._metric_groups)\n    self._evaluations.loc[(\"Real\", None, None)] = self._step(real_dataset)\n    pbar = tqdm(synthetic_datasets.iterrows(), desc=\"Evaluating\", total=len(synthetic_datasets))\n    for i, dataset in pbar:\n        pbar.set_description(f\"Evaluating {i[0]}, repeat {i[1]}, config {i[2]}\")\n        self._evaluations.loc[i] = self._step(real_dataset, dataset.values[0])\n
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.EvalFrame.get_evaluations","title":"get_evaluations()","text":"

Unpack the self._evaluations dataframe, where each metric group is a column, into a dict of dataframes.

Returns:

Type Description dict[str, DataFrame]

A dict of dataframes, one for each metric group, containing the evaluations.

Source code in src/nhssynth/modules/evaluation/utils.py
def get_evaluations(self) -> dict[str, pd.DataFrame]:\n    \"\"\"\n    Unpack the `self._evaluations` dataframe, where each metric group is a column, into a dict of dataframes.\n\n    Returns:\n        A dict of dataframes, one for each metric group, containing the evaluations.\n    \"\"\"\n    assert hasattr(\n        self, \"_evaluations\"\n    ), \"You must first run `evaluate` on a `real_dataset` and set of `synthetic_datasets`.\"\n    return {\n        metric_group: pd.DataFrame(\n            self._evaluations[metric_group].values.tolist(), index=self._evaluations.index\n        ).dropna(how=\"all\")\n        for metric_group in self._metric_groups\n    }\n
"},{"location":"reference/modules/evaluation/utils/#nhssynth.modules.evaluation.utils.validate_metric_args","title":"validate_metric_args(args, fn_dataset, columns)","text":"

Validate the arguments for downstream tasks and Aequitas.

Parameters:

Name Type Description Default args Namespace

The argument namespace to validate.

required fn_dataset str

The name of the dataset.

required columns Index

The columns in the dataset.

required

Returns:

Type Description tuple[list[Task], Namespace]

The validated arguments, the list of tasks and the list of metrics.

Source code in src/nhssynth/modules/evaluation/utils.py
def validate_metric_args(\n    args: argparse.Namespace, fn_dataset: str, columns: pd.Index\n) -> tuple[list[Task], argparse.Namespace]:\n    \"\"\"\n    Validate the arguments for downstream tasks and Aequitas.\n\n    Args:\n        args: The argument namespace to validate.\n        fn_dataset: The name of the dataset.\n        columns: The columns in the dataset.\n\n    Returns:\n        The validated arguments, the list of tasks and the list of metrics.\n    \"\"\"\n    if args.downstream_tasks:\n        tasks = get_tasks(fn_dataset, args.tasks_dir)\n        if not tasks:\n            warnings.warn(\"No valid downstream tasks found.\")\n    else:\n        tasks = []\n    if args.aequitas:\n        if not args.downstream_tasks or not any([task.supports_aequitas for task in tasks]):\n            warnings.warn(\n                \"Aequitas can only work in context of downstream tasks involving binary classification problems.\"\n            )\n        if not args.aequitas_attributes:\n            warnings.warn(\"No attributes specified for Aequitas analysis, defaulting to all columns in the dataset.\")\n            args.aequitas_attributes = columns.tolist()\n        assert all(\n            [attr in columns for attr in args.aequitas_attributes]\n        ), \"Invalid attribute(s) specified for Aequitas analysis.\"\n    metrics = {}\n    for metric_group in METRIC_CHOICES:\n        selected_metrics = getattr(args, \"_\".join(metric_group.split()).lower() + \"_metrics\") or []\n        metrics.update({metric_name: METRIC_CHOICES[metric_group][metric_name] for metric_name in selected_metrics})\n    return args, tasks, metrics\n
"},{"location":"reference/modules/model/","title":"model","text":""},{"location":"reference/modules/model/io/","title":"io","text":""},{"location":"reference/modules/model/io/#nhssynth.modules.model.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_transformed, fn_metatransformer, dir_experiment)","text":"

Sets up the input and output paths for the model files.

Parameters:

Name Type Description Default fn_dataset str

The base name of the dataset.

required fn_transformed str

The name of the transformed data file.

required fn_metatransformer str

The name of the metatransformer file.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, str]

The paths to the data, metadata and metatransformer files.

Source code in src/nhssynth/modules/model/io.py
def check_input_paths(\n    fn_dataset: str, fn_transformed: str, fn_metatransformer: str, dir_experiment: Path\n) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_transformed: The name of the transformed data file.\n        fn_metatransformer: The name of the metatransformer file.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset = Path(fn_dataset).stem\n    fn_transformed, fn_metatransformer = io.consistent_endings([fn_transformed, fn_metatransformer])\n    fn_transformed, fn_metatransformer = io.potential_suffixes([fn_transformed, fn_metatransformer], fn_dataset)\n    io.warn_if_path_supplied([fn_transformed, fn_metatransformer], dir_experiment)\n    io.check_exists([fn_transformed, fn_metatransformer], dir_experiment)\n    return fn_dataset, fn_transformed, fn_metatransformer\n
"},{"location":"reference/modules/model/io/#nhssynth.modules.model.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args, or from disk when the dataloader has not been run previously.

Parameters:

Name Type Description Default args Namespace

The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.

required dir_experiment Path

The path to the experiment directory.

required

Returns:

Type Description tuple[str, DataFrame, dict[str, int], MetaTransformer]

The data, metadata and metatransformer.

Source code in src/nhssynth/modules/model/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, dict[str, int], MetaTransformer]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The data, metadata and metatransformer.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"transformed\", \"metatransformer\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"transformed\"],\n            args.module_handover[\"metatransformer\"],\n        )\n    else:\n        fn_dataset, fn_transformed, fn_metatransformer = check_input_paths(\n            args.dataset, args.transformed, args.metatransformer, dir_experiment\n        )\n\n        with open(dir_experiment / fn_transformed, \"rb\") as f:\n            data = pickle.load(f)\n        with open(dir_experiment / fn_metatransformer, \"rb\") as f:\n            mt = pickle.load(f)\n\n        return fn_dataset, data, mt\n
"},{"location":"reference/modules/model/run/","title":"run","text":""},{"location":"reference/modules/model/utils/","title":"utils","text":""},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.configs_from_arg_combinations","title":"configs_from_arg_combinations(args, arg_list)","text":"

Generates a list of configurations from a list of arguments. Each configuration is one of a cartesian product of the arguments provided and identified in arg_list.

Parameters:

Name Type Description Default args Namespace

The arguments.

required arg_list list[str]

The list of arguments to generate configurations from.

required

Returns:

Type Description list[dict[str, Any]]

A list of configurations.

Source code in src/nhssynth/modules/model/utils.py
def configs_from_arg_combinations(args: argparse.Namespace, arg_list: list[str]) -> list[dict[str, Any]]:\n    \"\"\"\n    Generates a list of configurations from a list of arguments. Each configuration is one of a cartesian product of\n    the arguments provided and identified in `arg_list`.\n\n    Args:\n        args: The arguments.\n        arg_list: The list of arguments to generate configurations from.\n\n    Returns:\n        A list of configurations.\n    \"\"\"\n    wrapped_args = {arg: wrap_arg(getattr(args, arg)) for arg in arg_list}\n    combinations = list(itertools.product(*wrapped_args.values()))\n    return [{k: v for k, v in zip(wrapped_args.keys(), values) if v is not None} for values in combinations]\n
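For example, list-valued arguments are expanded into their cartesian product while scalar arguments are broadcast across every configuration:

import argparse

args = argparse.Namespace(num_epochs=[50, 100], patience=5)
configs_from_arg_combinations(args, ["num_epochs", "patience"])
# -> [{'num_epochs': 50, 'patience': 5}, {'num_epochs': 100, 'patience': 5}]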
"},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.get_experiments","title":"get_experiments(args)","text":"

Generates a dataframe of experiments from the arguments provided.

Parameters:

Name Type Description Default args Namespace

The arguments.

required

Returns:

Type Description DataFrame

A dataframe of experiments indexed by architecture, repeat and config ID.

Source code in src/nhssynth/modules/model/utils.py
def get_experiments(args: argparse.Namespace) -> pd.DataFrame:\n    \"\"\"\n    Generates a dataframe of experiments from the arguments provided.\n\n    Args:\n        args: The arguments.\n\n    Returns:\n        A dataframe of experiments indexed by architecture, repeat and config ID.\n    \"\"\"\n    experiments = pd.DataFrame(\n        columns=[\"architecture\", \"repeat\", \"config\", \"model_config\", \"seed\", \"train_config\", \"num_configs\"]\n    )\n    train_configs = configs_from_arg_combinations(args, [\"num_epochs\", \"patience\"])\n    for arch_name, repeat in itertools.product(*[wrap_arg(args.architecture), list(range(args.repeats))]):\n        arch = MODELS[arch_name]\n        model_configs = configs_from_arg_combinations(args, arch.get_args() + [\"batch_size\", \"use_gpu\"])\n        for i, (train_config, model_config) in enumerate(itertools.product(train_configs, model_configs)):\n            experiments.loc[len(experiments.index)] = {\n                \"architecture\": arch_name,\n                \"repeat\": repeat + 1,\n                \"config\": i + 1,\n                \"model_config\": model_config,\n                \"num_configs\": len(model_configs) * len(train_configs),\n                \"seed\": args.seed + repeat if args.seed else None,\n                \"train_config\": train_config,\n            }\n    return experiments.set_index([\"architecture\", \"repeat\", \"config\"], drop=True)\n
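A hedged sketch of the expected shape of the output; the architecture name and argument values are hypothetical, and any architecture-specific arguments returned by MODELS[...].get_args() are assumed to be present on args.

import argparse

args = argparse.Namespace(
    architecture=["VAE"],   # assumed to be a key of MODELS
    repeats=2,
    seed=123,
    num_epochs=[50, 100],
    patience=5,
    batch_size=32,
    use_gpu=False,
    # ...plus any architecture-specific arguments expected by MODELS["VAE"].get_args()
)
experiments = get_experiments(args)
# Indexed by (architecture, repeat, config): here 2 repeats x 2 train configs = 4 rows,
# assuming the architecture-specific arguments are single-valued.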
"},{"location":"reference/modules/model/utils/#nhssynth.modules.model.utils.wrap_arg","title":"wrap_arg(arg)","text":"

Wraps a single argument in a list if it is not already a list or tuple.

Parameters:

Name Type Description Default arg Any

The argument to wrap.

required

Returns:

Type Description Union[list, tuple]

The wrapped argument.

Source code in src/nhssynth/modules/model/utils.py
def wrap_arg(arg: Any) -> Union[list, tuple]:\n    \"\"\"\n    Wraps a single argument in a list if it is not already a list or tuple.\n\n    Args:\n        arg: The argument to wrap.\n\n    Returns:\n        The wrapped argument.\n    \"\"\"\n    if not isinstance(arg, list) and not isinstance(arg, tuple):\n        return [arg]\n    return arg\n
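For example:

wrap_arg(0.001)      # -> [0.001]
wrap_arg([50, 100])  # -> [50, 100] (lists are returned unchanged)
wrap_arg((32, 64))   # -> (32, 64) (tuples are also passed through)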
"},{"location":"reference/modules/model/common/","title":"common","text":""},{"location":"reference/modules/model/common/dp/","title":"dp","text":""},{"location":"reference/modules/model/common/dp/#nhssynth.modules.model.common.dp.DPMixin","title":"DPMixin","text":"

Bases: ABC

Mixin class to make a Model differentially private

Parameters:

Name Type Description Default target_epsilon float

The target epsilon for the model during training

3.0 target_delta Optional[float]

The target delta for the model during training

None max_grad_norm float

The maximum norm for the gradients; they are trimmed to this norm if they are larger

5.0 secure_mode bool

Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the csprng package

False

Attributes:

Name Type Description target_epsilon float

The target epsilon for the model during training

target_delta float

The target delta for the model during training

max_grad_norm float

The maximum norm for the gradients; they are trimmed to this norm if they are larger

secure_mode bool

Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the csprng package

Raises:

Type Description TypeError

If the inheritor is not a Model

Source code in src/nhssynth/modules/model/common/dp.py
class DPMixin(ABC):\n    \"\"\"\n    Mixin class to make a [`Model`][nhssynth.modules.model.common.model.Model] differentially private\n\n    Args:\n        target_epsilon: The target epsilon for the model during training\n        target_delta: The target delta for the model during training\n        max_grad_norm: The maximum norm for the gradients, they are trimmed to this norm if they are larger\n        secure_mode: Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the `csprng` package\n\n    Attributes:\n        target_epsilon: The target epsilon for the model during training\n        target_delta: The target delta for the model during training\n        max_grad_norm: The maximum norm for the gradients, they are trimmed to this norm if they are larger\n        secure_mode: Whether to use the 'secure mode' of PyTorch's DP-SGD implementation via the `csprng` package\n\n    Raises:\n        TypeError: If the inheritor is not a `Model`\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        target_epsilon: float = 3.0,\n        target_delta: Optional[float] = None,\n        max_grad_norm: float = 5.0,\n        secure_mode: bool = False,\n        **kwargs,\n    ):\n        if not isinstance(self, Model):\n            raise TypeError(\"DPMixin can only be used with Model classes\")\n        super(DPMixin, self).__init__(*args, **kwargs)\n        self.target_epsilon: float = target_epsilon\n        self.target_delta: float = target_delta or 1 / self.nrows\n        self.max_grad_norm: float = max_grad_norm\n        self.secure_mode: bool = secure_mode\n\n    def make_private(self, num_epochs: int, module: Optional[nn.Module] = None) -> GradSampleModule:\n        \"\"\"\n        Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.\n\n        Args:\n            num_epochs: The number of epochs to train for, used to calculate the privacy budget.\n            module: The module to make private.\n\n        Returns:\n            The privatised module.\n        \"\"\"\n        module = module or self\n        self.privacy_engine = PrivacyEngine(secure_mode=self.secure_mode)\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n            warnings.filterwarnings(\"ignore\", message=\"Optimal order is the largest alpha\")\n            module, module.optim, self.data_loader = self.privacy_engine.make_private_with_epsilon(\n                module=module,\n                optimizer=module.optim,\n                data_loader=self.data_loader,\n                epochs=num_epochs,\n                target_epsilon=self.target_epsilon,\n                target_delta=self.target_delta,\n                max_grad_norm=self.max_grad_norm,\n            )\n        print(\n            f\"Using sigma={module.optim.noise_multiplier} and C={self.max_grad_norm} to target (\u03b5, \u03b4) = ({self.target_epsilon}, {self.target_delta})-differential privacy.\".format()\n        )\n        self.get_epsilon = self.privacy_engine.accountant.get_epsilon\n        return module\n\n    def _generate_metric_str(self, key) -> str:\n        \"\"\"Generates a string to display the current value of the metric `key`.\"\"\"\n        if key == \"Privacy\":\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n                warnings.filterwarnings(\"ignore\", 
message=\"Optimal order is the largest alpha\")\n                val = self.get_epsilon(self.target_delta)\n            self.metrics[key] = np.append(self.metrics[key], val)\n            return f\"{(key + ' \u03b5 Spent:').ljust(self.max_length)}  {val:.4f}\"\n        else:\n            return super()._generate_metric_str(key)\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\"target_epsilon\", \"target_delta\", \"max_grad_norm\", \"secure_mode\"]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\"Privacy\"]\n\n    def _start_training(self, num_epochs, patience, displayed_metrics):\n        self.make_private(num_epochs)\n        super()._start_training(num_epochs, patience, displayed_metrics)\n
"},{"location":"reference/modules/model/common/dp/#nhssynth.modules.model.common.dp.DPMixin.make_private","title":"make_private(num_epochs, module=None)","text":"

Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.

Parameters:

Name Type Description Default num_epochs int

The number of epochs to train for, used to calculate the privacy budget.

required module Optional[Module]

The module to make private.

None

Returns:

Type Description GradSampleModule

The privatised module.

Source code in src/nhssynth/modules/model/common/dp.py
def make_private(self, num_epochs: int, module: Optional[nn.Module] = None) -> GradSampleModule:\n    \"\"\"\n    Make the passed module (or the full model if a module is not passed), and its associated optimizer and data loader private.\n\n    Args:\n        num_epochs: The number of epochs to train for, used to calculate the privacy budget.\n        module: The module to make private.\n\n    Returns:\n        The privatised module.\n    \"\"\"\n    module = module or self\n    self.privacy_engine = PrivacyEngine(secure_mode=self.secure_mode)\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\", message=\"invalid value encountered in log\")\n        warnings.filterwarnings(\"ignore\", message=\"Optimal order is the largest alpha\")\n        module, module.optim, self.data_loader = self.privacy_engine.make_private_with_epsilon(\n            module=module,\n            optimizer=module.optim,\n            data_loader=self.data_loader,\n            epochs=num_epochs,\n            target_epsilon=self.target_epsilon,\n            target_delta=self.target_delta,\n            max_grad_norm=self.max_grad_norm,\n        )\n    print(\n        f\"Using sigma={module.optim.noise_multiplier} and C={self.max_grad_norm} to target (\u03b5, \u03b4) = ({self.target_epsilon}, {self.target_delta})-differential privacy.\".format()\n    )\n    self.get_epsilon = self.privacy_engine.accountant.get_epsilon\n    return module\n
"},{"location":"reference/modules/model/common/mlp/","title":"mlp","text":""},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MLP","title":"MLP","text":"

Bases: Module

Fully connected or residual neural nets for classification and regression.

"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MLP--parameters","title":"Parameters","text":"

task_type: str
    classification or regression
n_units_in: int
    Number of features
n_units_out: int
    Number of outputs
n_layers_hidden: int
    Number of hidden layers
n_units_hidden: int
    Number of hidden units in each layer
nonlin: string, default 'elu'
    Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu', 'tanh' or 'leaky_relu'.
lr: float
    learning rate for optimizer.
weight_decay: float
    l2 (ridge) penalty for the weights.
n_iter: int
    Maximum number of iterations.
batch_size: int
    Batch size
n_iter_print: int
    Number of iterations after which to print updates and check the validation loss.
random_state: int
    random_state used
patience: int
    Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
    Minimum number of iterations to go through before starting early stopping
dropout: float
    Dropout value. If 0, the dropout is not used.
clipping_value: int, default 1
    Gradients clipping value
batch_norm: bool
    Enable/disable batch norm
early_stopping: bool
    Enable/disable early stopping
residual: bool
    Add residuals.
loss: Callable
    Optional Custom loss function. If None, the loss is CrossEntropy for classification tasks, or RMSE for regression.

Source code in src/nhssynth/modules/model/common/mlp.py
class MLP(nn.Module):\n    \"\"\"\n    Fully connected or residual neural nets for classification and regression.\n\n    Parameters\n    ----------\n    task_type: str\n        classification or regression\n    n_units_int: int\n        Number of features\n    n_units_out: int\n        Number of outputs\n    n_layers_hidden: int\n        Number of hidden layers\n    n_units_hidden: int\n        Number of hidden units in each layer\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu', 'tanh' or 'leaky_relu'.\n    lr: float\n        learning rate for optimizer.\n    weight_decay: float\n        l2 (ridge) penalty for the weights.\n    n_iter: int\n        Maximum number of iterations.\n    batch_size: int\n        Batch size\n    n_iter_print: int\n        Number of iterations after which to print updates and check the validation loss.\n    random_state: int\n        random_state used\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    dropout: float\n        Dropout value. If 0, the dropout is not used.\n    clipping_value: int, default 1\n        Gradients clipping value\n    batch_norm: bool\n        Enable/disable batch norm\n    early_stopping: bool\n        Enable/disable early stopping\n    residual: bool\n        Add residuals.\n    loss: Callable\n        Optional Custom loss function. If None, the loss is CrossEntropy for classification tasks, or RMSE for regression.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_units_in: int,\n        n_units_out: int,\n        n_layers_hidden: int = 1,\n        n_units_hidden: int = 100,\n        activation: str = \"relu\",\n        activation_out: Optional[list[tuple[str, int]]] = None,\n        lr: float = 1e-3,\n        weight_decay: float = 1e-3,\n        opt_betas: tuple = (0.9, 0.999),\n        n_iter: int = 1000,\n        batch_size: int = 500,\n        n_iter_print: int = 100,\n        patience: int = 10,\n        n_iter_min: int = 100,\n        dropout: float = 0.1,\n        clipping_value: int = 1,\n        batch_norm: bool = False,\n        early_stopping: bool = True,\n        residual: bool = False,\n        loss: Optional[Callable] = None,\n    ) -> None:\n        super(MLP, self).__init__()\n        activation = ACTIVATION_FUNCTIONS[activation] if activation in ACTIVATION_FUNCTIONS else None\n\n        if n_units_in < 0:\n            raise ValueError(\"n_units_in must be >= 0\")\n        if n_units_out < 0:\n            raise ValueError(\"n_units_out must be >= 0\")\n\n        if residual:\n            block = ResidualLayer\n        else:\n            block = LinearLayer\n\n        # network\n        layers = []\n\n        if n_layers_hidden > 0:\n            layers.append(\n                block(\n                    n_units_in,\n                    n_units_hidden,\n                    batch_norm=batch_norm,\n                    activation=activation,\n                )\n            )\n            n_units_hidden += int(residual) * n_units_in\n\n            # add required number of layers\n            for i in range(n_layers_hidden - 1):\n                layers.append(\n                    block(\n                        n_units_hidden,\n                        n_units_hidden,\n                        batch_norm=batch_norm,\n                        activation=activation,\n                        
dropout=dropout,\n                    )\n                )\n                n_units_hidden += int(residual) * n_units_hidden\n\n            # add final layers\n            layers.append(nn.Linear(n_units_hidden, n_units_out))\n        else:\n            layers = [nn.Linear(n_units_in, n_units_out)]\n\n        if activation_out is not None:\n            total_nonlin_len = 0\n            activations = []\n            for nonlin, nonlin_len in activation_out:\n                total_nonlin_len += nonlin_len\n                activations.append((ACTIVATION_FUNCTIONS[nonlin](), nonlin_len))\n\n            if total_nonlin_len != n_units_out:\n                raise RuntimeError(\n                    f\"Shape mismatch for the output layer. Expected length {n_units_out}, but got {activation_out} with length {total_nonlin_len}\"\n                )\n            layers.append(MultiActivationHead(activations))\n\n        self.model = nn.Sequential(*layers)\n\n        # optimizer\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.opt_betas = opt_betas\n        self.optimizer = torch.optim.Adam(\n            self.parameters(),\n            lr=self.lr,\n            weight_decay=self.weight_decay,\n            betas=self.opt_betas,\n        )\n\n        # training\n        self.n_iter = n_iter\n        self.n_iter_print = n_iter_print\n        self.n_iter_min = n_iter_min\n        self.batch_size = batch_size\n        self.patience = patience\n        self.clipping_value = clipping_value\n        self.early_stopping = early_stopping\n        if loss is not None:\n            self.loss = loss\n        else:\n            self.loss = nn.MSELoss()\n\n    def fit(self, X: np.ndarray, y: np.ndarray) -> \"MLP\":\n        Xt = self._check_tensor(X)\n        yt = self._check_tensor(y)\n\n        self._train(Xt, yt)\n\n        return self\n\n    def predict_proba(self, X: np.ndarray) -> np.ndarray:\n        if self.task_type != \"classification\":\n            raise ValueError(f\"Invalid task type for predict_proba {self.task_type}\")\n\n        with torch.no_grad():\n            Xt = self._check_tensor(X)\n\n            yt = self.forward(Xt)\n\n            return yt.cpu().numpy().squeeze()\n\n    def predict(self, X: np.ndarray) -> np.ndarray:\n        with torch.no_grad():\n            Xt = self._check_tensor(X)\n\n            yt = self.forward(Xt)\n\n            if self.task_type == \"classification\":\n                return np.argmax(yt.cpu().numpy().squeeze(), -1).squeeze()\n            else:\n                return yt.cpu().numpy().squeeze()\n\n    def score(self, X: np.ndarray, y: np.ndarray) -> float:\n        y_pred = self.predict(X)\n        if self.task_type == \"classification\":\n            return np.mean(y_pred == y)\n        else:\n            return np.mean(np.inner(y - y_pred, y - y_pred) / 2.0)\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        return self.model(X.float())\n\n    def _train_epoch(self, loader: DataLoader) -> float:\n        train_loss = []\n\n        for batch_ndx, sample in enumerate(loader):\n            self.optimizer.zero_grad()\n\n            X_next, y_next = sample\n            if len(X_next) < 2:\n                continue\n\n            preds = self.forward(X_next).squeeze()\n\n            batch_loss = self.loss(preds, y_next)\n\n            batch_loss.backward()\n\n            if self.clipping_value > 0:\n                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\n\n            self.optimizer.step()\n\n 
           train_loss.append(batch_loss.detach())\n\n        return torch.mean(torch.Tensor(train_loss))\n\n    def _train(self, X: torch.Tensor, y: torch.Tensor) -> \"MLP\":\n        X = self._check_tensor(X).float()\n        y = self._check_tensor(y).squeeze().float()\n        if self.task_type == \"classification\":\n            y = y.long()\n\n        # Load Dataset\n        dataset = TensorDataset(X, y)\n\n        train_size = int(0.8 * len(dataset))\n        test_size = len(dataset) - train_size\n        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])\n        loader = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=False)\n\n        # Setup the network and optimizer\n        val_loss_best = 1e12\n        patience = 0\n\n        # do training\n        for i in range(self.n_iter):\n            self._train_epoch(loader)\n\n            if self.early_stopping or i % self.n_iter_print == 0:\n                with torch.no_grad():\n                    X_val, y_val = test_dataset.dataset.tensors\n\n                    preds = self.forward(X_val).squeeze()\n                    val_loss = self.loss(preds, y_val)\n\n                    if self.early_stopping:\n                        if val_loss_best > val_loss:\n                            val_loss_best = val_loss\n                            patience = 0\n                        else:\n                            patience += 1\n\n                        if patience > self.patience and i > self.n_iter_min:\n                            break\n\n        return self\n\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\n        if isinstance(X, torch.Tensor):\n            return X\n        else:\n            return torch.from_numpy(np.asarray(X))\n\n    def __len__(self) -> int:\n        return len(self.model)\n
"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.MultiActivationHead","title":"MultiActivationHead","text":"

Bases: Module

Final layer with multiple activations. Useful for tabular data.

Source code in src/nhssynth/modules/model/common/mlp.py
class MultiActivationHead(nn.Module):\n    \"\"\"Final layer with multiple activations. Useful for tabular data.\"\"\"\n\n    def __init__(\n        self,\n        activations: list[tuple[nn.Module, int]],\n    ) -> None:\n        super(MultiActivationHead, self).__init__()\n        self.activations = []\n        self.activation_lengths = []\n\n        for activation, length in activations:\n            self.activations.append(activation)\n            self.activation_lengths.append(length)\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        if X.shape[-1] != np.sum(self.activation_lengths):\n            raise RuntimeError(\n                f\"Shape mismatch for the activations: expected {np.sum(self.activation_lengths)}. Got shape {X.shape}.\"\n            )\n\n        split = 0\n        out = torch.zeros(X.shape)\n\n        for activation, step in zip(self.activations, self.activation_lengths):\n            out[..., split : split + step] = activation(X[..., split : split + step])\n            split += step\n\n        return out\n
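For example, different slices of the output can be given different activations, which is useful when the output concatenates heterogeneous (e.g. one-hot encoded) columns; the dimensions below are arbitrary.

import torch
import torch.nn as nn

# Sigmoid over the first 3 outputs, softmax over the remaining 4; the lengths (3 + 4)
# must sum to the final dimension of the input or forward() raises a RuntimeError.
head = MultiActivationHead([(nn.Sigmoid(), 3), (nn.Softmax(dim=-1), 4)])
out = head(torch.randn(8, 7))  # shape (8, 7)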
"},{"location":"reference/modules/model/common/mlp/#nhssynth.modules.model.common.mlp.SkipConnection","title":"SkipConnection(cls)","text":"

Wraps a model to add a skip connection from the input to the output.

Example:

>>> ResidualBlock = SkipConnection(MLP)
>>> res_block = ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64)
>>> res_block(torch.ones(10, 10)).shape
(10, 13)

Source code in src/nhssynth/modules/model/common/mlp.py
def SkipConnection(cls: Type[nn.Module]) -> Type[nn.Module]:\n    \"\"\"Wraps a model to add a skip connection from the input to the output.\n\n    Example:\n    >>> ResidualBlock = SkipConnection(MLP)\n    >>> res_block = ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64)\n    >>> res_block(torch.ones(10, 10)).shape\n    (10, 13)\n    \"\"\"\n\n    class Wrapper(cls):\n        pass\n\n    Wrapper._forward = cls.forward\n    Wrapper.forward = _forward_skip_connection\n    Wrapper.__name__ = f\"SkipConnection({cls.__name__})\"\n    Wrapper.__qualname__ = f\"SkipConnection({cls.__qualname__})\"\n    Wrapper.__doc__ = f\"\"\"(With skipped connection) {cls.__doc__}\"\"\"\n    return Wrapper\n
"},{"location":"reference/modules/model/common/model/","title":"model","text":""},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model","title":"Model","text":"

Bases: Module, ABC

Abstract base class for all NHSSynth models

Parameters:

    data (DataFrame, required): The data to train on
    metatransformer (MetaTransformer, required): A MetaTransformer to use for converting the generated data to match the original data
    batch_size (int, default 32): The batch size to use during training
    use_gpu (bool, default False): Flag to determine whether to use the GPU (if available)

Attributes:

    nrows: The number of rows in the data
    ncols: The number of columns in the data
    columns (Index): The names of the columns in the data
    metatransformer: The MetaTransformer (potentially) associated with the model
    multi_column_indices (list[list[int]]): A list of lists of column indices, where each sublist contains the indices for a one-hot encoded column
    single_column_indices (list[int]): Indices of all non-onehot columns
    data_loader (DataLoader): A PyTorch DataLoader for the data
    private: Whether the model is private, i.e. whether the DPMixin class has been inherited
    device: The device to use for training (CPU or GPU)

Raises:

    TypeError: If the Model class is directly instantiated (i.e. not inherited)
    AssertionError: If the number of columns in the data does not match the number of indices in multi_column_indices and single_column_indices
    UserWarning: If use_gpu is True but no GPU is available

Source code in src/nhssynth/modules/model/common/model.py
class Model(nn.Module, ABC):\n    \"\"\"\n    Abstract base class for all NHSSynth models\n\n    Args:\n        data: The data to train on\n        metatransformer: A `MetaTransformer` to use for converting the generated data to match the original data\n        batch_size: The batch size to use during training\n        use_gpu: Flag to determine whether to use the GPU (if available)\n\n    Attributes:\n        nrows: The number of rows in the `data`\n        ncols: The number of columns in the `data`\n        columns: The names of the columns in the `data`\n        metatransformer: The `MetaTransformer` (potentially) associated with the model\n        multi_column_indices: A list of lists of column indices, where each sublist containts the indices for a one-hot encoded column\n        single_column_indices: Indices of all non-onehot columns\n        data_loader: A PyTorch DataLoader for the `data`\n        private: Whether the model is private, i.e. whether the `DPMixin` class has been inherited\n        device: The device to use for training (CPU or GPU)\n\n    Raises:\n        TypeError: If the `Model` class is directly instantiated (i.e. not inherited)\n        AssertionError: If the number of columns in the `data` does not match the number of indices in `multi_column_indices` and `single_column_indices`\n        UserWarning: If `use_gpu` is True but no GPU is available\n    \"\"\"\n\n    def __init__(\n        self,\n        data: pd.DataFrame,\n        metatransformer: MetaTransformer,\n        cond: Optional[Union[pd.DataFrame, pd.Series, np.ndarray]] = None,\n        batch_size: int = 32,\n        use_gpu: bool = False,\n    ) -> None:\n        if type(self) is Model:\n            raise TypeError(\"Cannot directly instantiate the `Model` class\")\n        super().__init__()\n\n        self.nrows, self.ncols = data.shape\n        self.columns: pd.Index = data.columns\n\n        self.batch_size = batch_size\n\n        self.metatransformer = metatransformer\n        self.multi_column_indices: list[list[int]] = metatransformer.multi_column_indices\n        self.single_column_indices: list[int] = metatransformer.single_column_indices\n        assert len(self.single_column_indices) + sum([len(x) for x in self.multi_column_indices]) == self.ncols\n\n        tensor_data = torch.Tensor(data.to_numpy())\n        self.cond_encoder: Optional[OneHotEncoder] = None\n        if cond is not None:\n            cond = np.asarray(cond)\n            if len(cond.shape) == 1:\n                cond = cond.reshape(-1, 1)\n            self.cond_encoder = OneHotEncoder(handle_unknown=\"ignore\").fit(cond)\n            cond = self.cond_encoder.transform(cond).toarray()\n            self.n_units_conditional = cond.shape[-1]\n            dataset = TensorDataset(tensor_data, cond)\n        else:\n            self.n_units_conditional = 0\n            dataset = TensorDataset(tensor_data)\n\n        self.data_loader: DataLoader = DataLoader(\n            dataset,\n            pin_memory=True,\n            batch_size=self.batch_size,\n        )\n        self.setup_device(use_gpu)\n\n    def setup_device(self, use_gpu: bool) -> None:\n        \"\"\"Sets up the device to use for training (CPU or GPU) depending on `use_gpu` and device availability.\"\"\"\n        if use_gpu:\n            if torch.cuda.is_available():\n                self.device: torch.device = torch.device(\"cuda:0\")\n            else:\n                warnings.warn(\"`use_gpu` was provided but no GPU is available, using CPU\")\n        self.device: 
torch.device = torch.device(\"cpu\")\n\n    def save(self, filename: str) -> None:\n        \"\"\"Saves the model to `filename`.\"\"\"\n        torch.save(self.state_dict(), filename)\n\n    def load(self, path: str) -> None:\n        \"\"\"Loads the model from `path`.\"\"\"\n        self.load_state_dict(torch.load(path))\n\n    @classmethod\n    @abstractmethod\n    def get_args() -> list[str]:\n        \"\"\"Returns the list of arguments to look for in an `argparse.Namespace`, these must map to the arguments of the inheritor.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def get_metrics() -> list[str]:\n        \"\"\"Returns the list of metrics to track during training.\"\"\"\n        raise NotImplementedError\n\n    def _start_training(self, num_epochs: int, patience: int, displayed_metrics: list[str]) -> None:\n        \"\"\"\n        Initialises the training process.\n\n        Args:\n            num_epochs: The number of epochs to train for\n            patience: The number of epochs to wait before stopping training early if the loss does not improve\n            displayed_metrics: The metrics to display during training, this should be set to an empty list if running `train` in a notebook or the output may be messy\n\n        Attributes:\n            metrics: A dictionary of lists of tracked metrics, where each list contains the values for each batch\n            stats_bars: A dictionary of tqdm status bars for each tracked metric\n            max_length: The maximum length of the tracked metric names, used for formatting the tqdm status bars\n            start_time: The time at which training started\n            update_time: The time at which the tqdm status bars were last updated\n        \"\"\"\n        self.num_epochs = num_epochs\n        self.patience = patience\n        self.metrics = {metric: np.empty(0, dtype=float) for metric in self.get_metrics()}\n        displayed_metrics = displayed_metrics or self.get_metrics()\n        self.stats_bars = {\n            metric: tqdm(total=0, desc=\"\", position=i, bar_format=\"{desc}\", leave=True)\n            for i, metric in enumerate(displayed_metrics)\n        }\n        self.max_length = max([len(add_spaces_before_caps(s)) + 5 for s in displayed_metrics] + [20])\n        self.start_time = self.update_time = time.time()\n\n    def _generate_metric_str(self, key) -> str:\n        \"\"\"Generates a string to display the current value of the metric `key`.\"\"\"\n        return f\"{(add_spaces_before_caps(key) + ':').ljust(self.max_length)}  {np.mean(self.metrics[key][-len(self.data_loader) :]):.4f}\"\n\n    def _record_metrics(self, losses):\n        \"\"\"Records the metrics for the current batch to file and updates the tqdm status bars.\"\"\"\n        for key in self.metrics.keys():\n            if key in losses:\n                if losses[key]:\n                    self.metrics[key] = np.append(\n                        self.metrics[key], losses[key].item() if isinstance(losses[key], torch.Tensor) else losses[key]\n                    )\n        if time.time() - self.update_time > 0.5:\n            for key, stats_bar in self.stats_bars.items():\n                stats_bar.set_description_str(self._generate_metric_str(key))\n                self.update_time = time.time()\n\n    def _check_patience(self, epoch: int, metric: float) -> bool:\n        \"\"\"Maintains `_min_metric` and `_stop_counter` to determine whether to stop training early according to `patience`.\"\"\"\n        if epoch == 
0:\n            self._stop_counter = 0\n            self._min_metric = metric\n            self._patience_delta = self._min_metric / 1e4\n        if metric < (self._min_metric - self._patience_delta):\n            self._min_metric = metric\n            self._stop_counter = 0  # Set counter to zero\n        else:  # elbo has not improved\n            self._stop_counter += 1\n        return self._stop_counter == self.patience\n\n    def _finish_training(self, num_epochs: int) -> None:\n        \"\"\"Closes each of the tqdm status bars and prints the time taken to do `num_epochs`.\"\"\"\n        for stats_bar in self.stats_bars.values():\n            stats_bar.close()\n        tqdm.write(f\"Completed {num_epochs} epochs in {time.time() - self.start_time:.2f} seconds.\\033[0m\")\n
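Since Model cannot be instantiated directly, every concrete model subclasses it and supplies get_args and get_metrics alongside its own training logic. A hypothetical sketch of the required interface (ToyModel and hidden_dim are illustrative names, not part of the package):

import torch

from nhssynth.modules.model.common.model import Model

class ToyModel(Model):
    """Illustrative subclass showing the interface expected by the module."""

    def __init__(self, *args, hidden_dim: int = 16, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # `self.ncols` and `self.device` are provided by the base class constructor.
        self.net = torch.nn.Linear(self.ncols, hidden_dim).to(self.device)

    @classmethod
    def get_args(cls) -> list[str]:
        # Names looked up on the argparse.Namespace and forwarded to __init__.
        return ["hidden_dim"]

    @classmethod
    def get_metrics(cls) -> list[str]:
        return ["Loss"]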
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.get_args","title":"get_args() abstractmethod classmethod","text":"

Returns the list of arguments to look for in an argparse.Namespace; these must map to the arguments of the inheritor.

Source code in src/nhssynth/modules/model/common/model.py
@classmethod\n@abstractmethod\ndef get_args() -> list[str]:\n    \"\"\"Returns the list of arguments to look for in an `argparse.Namespace`, these must map to the arguments of the inheritor.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.get_metrics","title":"get_metrics() abstractmethod classmethod","text":"

Returns the list of metrics to track during training.

Source code in src/nhssynth/modules/model/common/model.py
@classmethod\n@abstractmethod\ndef get_metrics() -> list[str]:\n    \"\"\"Returns the list of metrics to track during training.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.load","title":"load(path)","text":"

Loads the model from path.

Source code in src/nhssynth/modules/model/common/model.py
def load(self, path: str) -> None:\n    \"\"\"Loads the model from `path`.\"\"\"\n    self.load_state_dict(torch.load(path))\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.save","title":"save(filename)","text":"

Saves the model to filename.

Source code in src/nhssynth/modules/model/common/model.py
def save(self, filename: str) -> None:\n    \"\"\"Saves the model to `filename`.\"\"\"\n    torch.save(self.state_dict(), filename)\n
"},{"location":"reference/modules/model/common/model/#nhssynth.modules.model.common.model.Model.setup_device","title":"setup_device(use_gpu)","text":"

Sets up the device to use for training (CPU or GPU) depending on use_gpu and device availability.

Source code in src/nhssynth/modules/model/common/model.py
def setup_device(self, use_gpu: bool) -> None:\n    \"\"\"Sets up the device to use for training (CPU or GPU) depending on `use_gpu` and device availability.\"\"\"\n    if use_gpu:\n        if torch.cuda.is_available():\n            self.device: torch.device = torch.device(\"cuda:0\")\n        else:\n            warnings.warn(\"`use_gpu` was provided but no GPU is available, using CPU\")\n    self.device: torch.device = torch.device(\"cpu\")\n
"},{"location":"reference/modules/model/models/","title":"models","text":""},{"location":"reference/modules/model/models/dpvae/","title":"dpvae","text":""},{"location":"reference/modules/model/models/dpvae/#nhssynth.modules.model.models.dpvae.DPVAE","title":"DPVAE","text":"

Bases: DPMixin, VAE

A differentially private VAE. Accepts VAE arguments as well as DPMixin arguments.

Source code in src/nhssynth/modules/model/models/dpvae.py
class DPVAE(DPMixin, VAE):\n    \"\"\"\n    A differentially private VAE. Accepts [`VAE`][nhssynth.modules.model.models.vae.VAE] arguments\n    as well as [`DPMixin`][nhssynth.modules.model.common.dp.DPMixin] arguments.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        target_epsilon: float = 3.0,\n        target_delta: Optional[float] = None,\n        max_grad_norm: float = 5.0,\n        secure_mode: bool = False,\n        shared_optimizer: bool = False,\n        **kwargs,\n    ) -> None:\n        super(DPVAE, self).__init__(\n            *args,\n            target_epsilon=target_epsilon,\n            target_delta=target_delta,\n            max_grad_norm=max_grad_norm,\n            secure_mode=secure_mode,\n            # TODO fix shared_optimizer workflow for DP models\n            shared_optimizer=False,\n            **kwargs,\n        )\n\n    def make_private(self, num_epochs: int) -> GradSampleModule:\n        \"\"\"\n        Make the [`Decoder`][nhssynth.modules.model.models.vae.Decoder] differentially private\n        unless `shared_optimizer` is True, in which case the whole VAE will be privatised.\n\n        Args:\n            num_epochs: The number of epochs to train for\n        \"\"\"\n        if self.shared_optimizer:\n            super().make_private(num_epochs)\n        else:\n            self.decoder = super().make_private(num_epochs, self.decoder)\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return VAE.get_args() + DPMixin.get_args()\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return VAE.get_metrics() + DPMixin.get_metrics()\n
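A hedged usage sketch, assuming transformed_data (a pd.DataFrame) and metatransformer (a fitted MetaTransformer) have already been produced by the dataloader module; both names are placeholders:

from nhssynth.modules.model.models.dpvae import DPVAE

model = DPVAE(
    transformed_data,        # placeholder: encoded training data
    metatransformer,         # placeholder: the associated MetaTransformer
    target_epsilon=3.0,      # privacy budget
    max_grad_norm=5.0,       # per-sample gradient clipping
)
num_epochs, metrics = model.train(num_epochs=50, patience=5)
synthetic = model.generate(1000)   # DataFrame mapped back to the original schema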
"},{"location":"reference/modules/model/models/dpvae/#nhssynth.modules.model.models.dpvae.DPVAE.make_private","title":"make_private(num_epochs)","text":"

Make the Decoder differentially private unless shared_optimizer is True, in which case the whole VAE will be privatised.

Parameters:

    num_epochs (int, required): The number of epochs to train for

Source code in src/nhssynth/modules/model/models/dpvae.py
def make_private(self, num_epochs: int) -> GradSampleModule:\n    \"\"\"\n    Make the [`Decoder`][nhssynth.modules.model.models.vae.Decoder] differentially private\n    unless `shared_optimizer` is True, in which case the whole VAE will be privatised.\n\n    Args:\n        num_epochs: The number of epochs to train for\n    \"\"\"\n    if self.shared_optimizer:\n        super().make_private(num_epochs)\n    else:\n        self.decoder = super().make_private(num_epochs, self.decoder)\n
"},{"location":"reference/modules/model/models/gan/","title":"gan","text":""},{"location":"reference/modules/model/models/gan/#nhssynth.modules.model.models.gan.GAN","title":"GAN","text":"

Bases: Model

Basic GAN implementation.

Parameters:

    n_units_conditional (int, default 0): Number of conditional units
    generator_n_layers_hidden (int, default 2): Number of hidden layers in the generator
    generator_n_units_hidden (int, default 250): Number of hidden units in each layer of the generator
    generator_activation (str, default 'leaky_relu'): Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
    generator_n_iter (int, required): Maximum number of iterations in the generator.
    generator_batch_norm (bool, default False): Enable/disable batch norm for the generator
    generator_dropout (float, default 0): Dropout value. If 0, the dropout is not used.
    generator_residual (bool, default True): Use residuals for the generator
    generator_activation_out (Optional[List[Tuple[str, int]]], required): List of activations. Useful with the TabularEncoder
    generator_lr (float, default 2e-4): Generator learning rate, used by the Adam optimizer
    generator_weight_decay (float, required): Generator weight decay, used by the Adam optimizer
    generator_opt_betas (tuple, default (0.9, 0.999)): Generator initial decay rates, used by the Adam optimizer
    generator_extra_penalty_cbks (List[Callable], required): Additional loss callbacks for the generator. Used by the TabularGAN for the conditional loss
    discriminator_n_layers_hidden (int, default 3): Number of hidden layers in the discriminator
    discriminator_n_units_hidden (int, default 300): Number of hidden units in each layer of the discriminator
    discriminator_activation (str, default 'leaky_relu'): Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
    discriminator_batch_norm (bool, default False): Enable/disable batch norm for the discriminator
    discriminator_dropout (float, default 0.1): Dropout value for the discriminator. If 0, the dropout is not used.
    discriminator_lr (float, default 2e-4): Discriminator learning rate, used by the Adam optimizer
    discriminator_weight_decay (float, required): Discriminator weight decay, used by the Adam optimizer
    discriminator_opt_betas (tuple, default (0.9, 0.999)): Discriminator initial decay rates, used by the Adam optimizer
    clipping_value (int, default 0): Gradient clipping value. Zero disables the feature
    lambda_gradient_penalty (float, default 10): Weight for the gradient penalty

Source code in src/nhssynth/modules/model/models/gan.py
class GAN(Model):\n    \"\"\"\n    Basic GAN implementation.\n\n    Args:\n        n_units_conditional: int\n            Number of conditional units\n        generator_n_layers_hidden: int\n            Number of hidden layers in the generator\n        generator_n_units_hidden: int\n            Number of hidden units in each layer of the Generator\n        generator_activation: string, default 'elu'\n            Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n        generator_n_iter: int\n            Maximum number of iterations in the Generator.\n        generator_batch_norm: bool\n            Enable/disable batch norm for the generator\n        generator_dropout: float\n            Dropout value. If 0, the dropout is not used.\n        generator_residual: bool\n            Use residuals for the generator\n        generator_activation_out: Optional[List[Tuple[str, int]]]\n            List of activations. Useful with the TabularEncoder\n        generator_lr: float = 2e-4\n            Generator learning rate, used by the Adam optimizer\n        generator_weight_decay: float = 1e-3\n            Generator weight decay, used by the Adam optimizer\n        generator_opt_betas: tuple = (0.9, 0.999)\n            Generator initial decay rates, used by the Adam Optimizer\n        generator_extra_penalty_cbks: List[Callable]\n            Additional loss callabacks for the generator. Used by the TabularGAN for the conditional loss\n        discriminator_n_layers_hidden: int\n            Number of hidden layers in the discriminator\n        discriminator_n_units_hidden: int\n            Number of hidden units in each layer of the discriminator\n        discriminator_activation: string, default 'relu'\n            Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n        discriminator_batch_norm: bool\n            Enable/disable batch norm for the discriminator\n        discriminator_dropout: float\n            Dropout value for the discriminator. If 0, the dropout is not used.\n        discriminator_lr: float\n            Discriminator learning rate, used by the Adam optimizer\n        discriminator_weight_decay: float\n            Discriminator weight decay, used by the Adam optimizer\n        discriminator_opt_betas: tuple\n            Initial weight decays for the Adam optimizer\n        clipping_value: int, default 0\n            Gradients clipping value. 
Zero disables the feature\n        lambda_gradient_penalty: float = 10\n            Weight for the gradient penalty\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        n_units_conditional: int = 0,\n        generator_n_layers_hidden: int = 2,\n        generator_n_units_hidden: int = 250,\n        generator_activation: str = \"leaky_relu\",\n        generator_batch_norm: bool = False,\n        generator_dropout: float = 0,\n        generator_lr: float = 2e-4,\n        generator_residual: bool = True,\n        generator_opt_betas: tuple = (0.9, 0.999),\n        discriminator_n_layers_hidden: int = 3,\n        discriminator_n_units_hidden: int = 300,\n        discriminator_activation: str = \"leaky_relu\",\n        discriminator_batch_norm: bool = False,\n        discriminator_dropout: float = 0.1,\n        discriminator_lr: float = 2e-4,\n        discriminator_opt_betas: tuple = (0.9, 0.999),\n        clipping_value: int = 0,\n        lambda_gradient_penalty: float = 10,\n        **kwargs,\n    ) -> None:\n        super(GAN, self).__init__(*args, **kwargs)\n\n        self.generator_n_units_hidden = generator_n_units_hidden\n        self.n_units_conditional = n_units_conditional\n\n        self.generator = MLP(\n            n_units_in=generator_n_units_hidden + n_units_conditional,\n            n_units_out=self.ncols,\n            n_layers_hidden=generator_n_layers_hidden,\n            n_units_hidden=generator_n_units_hidden,\n            activation=generator_activation,\n            # nonlin_out=generator_activation_out,\n            batch_norm=generator_batch_norm,\n            dropout=generator_dropout,\n            lr=generator_lr,\n            residual=generator_residual,\n            opt_betas=generator_opt_betas,\n        ).to(self.device)\n\n        self.discriminator = MLP(\n            n_units_in=self.ncols + n_units_conditional,\n            n_units_out=1,\n            n_layers_hidden=discriminator_n_layers_hidden,\n            n_units_hidden=discriminator_n_units_hidden,\n            activation=discriminator_activation,\n            activation_out=[(\"none\", 1)],\n            batch_norm=discriminator_batch_norm,\n            dropout=discriminator_dropout,\n            lr=discriminator_lr,\n            opt_betas=discriminator_opt_betas,\n        ).to(self.device)\n\n        self.clipping_value = clipping_value\n        self.lambda_gradient_penalty = lambda_gradient_penalty\n\n        def gen_fake_labels(X: torch.Tensor) -> torch.Tensor:\n            return torch.zeros((len(X),), device=self.device)\n\n        def gen_true_labels(X: torch.Tensor) -> torch.Tensor:\n            return torch.ones((len(X),), device=self.device)\n\n        self.fake_labels_generator = gen_fake_labels\n        self.true_labels_generator = gen_true_labels\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\n            \"n_units_conditional\",\n            \"generator_n_layers_hidden\",\n            \"generator_n_units_hidden\",\n            \"generator_activation\",\n            \"generator_batch_norm\",\n            \"generator_dropout\",\n            \"generator_lr\",\n            \"generator_residual\",\n            \"generator_opt_betas\",\n            \"discriminator_n_layers_hidden\",\n            \"discriminator_n_units_hidden\",\n            \"discriminator_activation\",\n            \"discriminator_batch_norm\",\n            \"discriminator_dropout\",\n            \"discriminator_lr\",\n            \"discriminator_opt_betas\",\n            
\"clipping_value\",\n            \"lambda_gradient_penalty\",\n        ]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\"GLoss\", \"DLoss\"]\n\n    def generate(self, N: int, cond: Optional[np.ndarray] = None) -> np.ndarray:\n        N = N or self.nrows\n        self.generator.eval()\n\n        condt: Optional[torch.Tensor] = None\n        if cond is not None:\n            condt = self._check_tensor(cond)\n        with torch.no_grad():\n            return self.metatransformer.inverse_apply(\n                pd.DataFrame(self(N, condt).detach().cpu().numpy(), columns=self.columns)\n            )\n\n    def forward(\n        self,\n        N: int,\n        cond: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if cond is None and self.n_units_conditional > 0:\n            # sample from the original conditional\n            if self._original_cond is None:\n                raise ValueError(\"Invalid original conditional. Provide a valid value.\")\n            cond_idxs = torch.randint(len(self._original_cond), (N,))\n            cond = self._original_cond[cond_idxs]\n\n        if cond is not None and len(cond.shape) == 1:\n            cond = cond.reshape(-1, 1)\n\n        if cond is not None and len(cond) != N:\n            raise ValueError(\"cond length must match N\")\n\n        fixed_noise = torch.randn(N, self.generator_n_units_hidden, device=self.device)\n        fixed_noise = self._append_optional_cond(fixed_noise, cond)\n\n        return self.generator(fixed_noise)\n\n    def _train_epoch_generator(\n        self,\n        X: torch.Tensor,\n        cond: Optional[torch.Tensor],\n    ) -> float:\n        # Update the G network\n        self.generator.train()\n        self.generator.optimizer.zero_grad()\n\n        real_X_raw = X.to(self.device)\n        real_X = self._append_optional_cond(real_X_raw, cond)\n        batch_size = len(real_X)\n\n        noise = torch.randn(batch_size, self.generator_n_units_hidden, device=self.device)\n        noise = self._append_optional_cond(noise, cond)\n\n        fake_raw = self.generator(noise)\n        fake = self._append_optional_cond(fake_raw, cond)\n\n        output = self.discriminator(fake).squeeze().float()\n        # Calculate G's loss based on this output\n        errG = -torch.mean(output)\n        if hasattr(self, \"generator_extra_penalty_cbks\"):\n            for extra_loss in self.generator_extra_penalty_cbks:\n                errG += extra_loss(\n                    real_X_raw,\n                    fake_raw,\n                    cond=cond,\n                )\n\n        # Calculate gradients for G\n        errG.backward()\n\n        # Update G\n        if self.clipping_value > 0:\n            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.clipping_value)\n        self.generator.optimizer.step()\n\n        if torch.isnan(errG):\n            raise RuntimeError(\"NaNs detected in the generator loss\")\n\n        # Return loss\n        return errG.item()\n\n    def _train_epoch_discriminator(\n        self,\n        X: torch.Tensor,\n        cond: Optional[torch.Tensor],\n    ) -> float:\n        # Update the D network\n        self.discriminator.train()\n\n        errors = []\n\n        batch_size = min(self.batch_size, len(X))\n\n        # Train with all-real batch\n        real_X = X.to(self.device)\n        real_X = self._append_optional_cond(real_X, cond)\n\n        real_labels = self.true_labels_generator(X).to(self.device).squeeze()\n        real_output = 
self.discriminator(real_X).squeeze().float()\n\n        # Train with all-fake batch\n        noise = torch.randn(batch_size, self.generator_n_units_hidden, device=self.device)\n        noise = self._append_optional_cond(noise, cond)\n\n        fake_raw = self.generator(noise)\n        fake = self._append_optional_cond(fake_raw, cond)\n\n        fake_labels = self.fake_labels_generator(fake_raw).to(self.device).squeeze().float()\n        fake_output = self.discriminator(fake.detach()).squeeze()\n\n        # Compute errors. Some fake inputs might be marked as real for privacy guarantees.\n\n        real_real_output = real_output[(real_labels * real_output) != 0]\n        real_fake_output = fake_output[(fake_labels * fake_output) != 0]\n        errD_real = torch.mean(torch.concat((real_real_output, real_fake_output)))\n\n        fake_real_output = real_output[((1 - real_labels) * real_output) != 0]\n        fake_fake_output = fake_output[((1 - fake_labels) * fake_output) != 0]\n        errD_fake = torch.mean(torch.concat((fake_real_output, fake_fake_output)))\n\n        penalty = self._loss_gradient_penalty(\n            real_samples=real_X,\n            fake_samples=fake,\n            batch_size=batch_size,\n        )\n        errD = -errD_real + errD_fake\n\n        self.discriminator.optimizer.zero_grad()\n        if isinstance(self, DPMixin):\n            # Adversarial loss\n            # 1. split fwd-bkwd on fake and real images into two explicit blocks.\n            # 2. no need to compute per_sample_gardients on fake data, disable hooks.\n            # 3. re-enable hooks to obtain per_sample_gardients for real data.\n            # fake fwd-bkwd\n            self.discriminator.disable_hooks()\n            penalty.backward(retain_graph=True)\n            errD_fake.backward(retain_graph=True)\n\n            self.discriminator.enable_hooks()\n            errD_real.backward()  # HACK: calling bkwd without zero_grad() accumulates param gradients\n        else:\n            penalty.backward(retain_graph=True)\n            errD.backward()\n\n        # Update D\n        if self.clipping_value > 0:\n            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.clipping_value)\n        self.discriminator.optimizer.step()\n\n        errors.append(errD.item())\n\n        if np.isnan(np.mean(errors)):\n            raise RuntimeError(\"NaNs detected in the discriminator loss\")\n\n        return np.mean(errors)\n\n    def _train_epoch(self) -> Tuple[float, float]:\n        for data in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n            cond: Optional[torch.Tensor] = None\n            if self.n_units_conditional > 0:\n                X, cond = data\n            else:\n                X = data[0]\n\n            losses = {\n                \"DLoss\": self._train_epoch_discriminator(X, cond),\n                \"GLoss\": self._train_epoch_generator(X, cond),\n            }\n            self._record_metrics(losses)\n\n        return np.mean(self.metrics[\"GLoss\"][-len(self.data_loader) :]), np.mean(\n            self.metrics[\"DLoss\"][-len(self.data_loader) :]\n        )\n\n    def train(\n        self,\n        num_epochs: int = 100,\n        patience: int = 5,\n        displayed_metrics: list[str] = [\"GLoss\", \"DLoss\"],\n    ) -> tuple[int, dict[str, np.ndarray]]:\n        self._start_training(num_epochs, patience, displayed_metrics)\n\n        for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), 
leave=False):\n            losses = self._train_epoch()\n            if self._check_patience(epoch, losses[0]) and self._check_patience(epoch, losses[1]):\n                num_epochs = epoch + 1\n                break\n\n        self._finish_training(num_epochs)\n        return (num_epochs, self.metrics)\n\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\n        if isinstance(X, torch.Tensor):\n            return X.to(self.device)\n        else:\n            return torch.from_numpy(np.asarray(X)).to(self.device)\n\n    def _loss_gradient_penalty(\n        self,\n        real_samples: torch.tensor,\n        fake_samples: torch.Tensor,\n        batch_size: int,\n    ) -> torch.Tensor:\n        \"\"\"Calculates the gradient penalty loss for WGAN GP\"\"\"\n        # Random weight term for interpolation between real and fake samples\n        alpha = torch.rand([batch_size, 1]).to(self.device)\n        # Get random interpolation between real and fake samples\n        interpolated = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)\n        d_interpolated = self.discriminator(interpolated).squeeze()\n        labels = torch.ones((len(interpolated),), device=self.device)\n\n        # Get gradient w.r.t. interpolates\n        gradients = torch.autograd.grad(\n            outputs=d_interpolated,\n            inputs=interpolated,\n            grad_outputs=labels,\n            create_graph=True,\n            retain_graph=True,\n            only_inputs=True,\n            allow_unused=True,\n        )[0]\n        gradients = gradients.view(gradients.size(0), -1)\n        gradient_penalty = ((gradients.norm(2, dim=-1) - 1) ** 2).mean()\n        return self.lambda_gradient_penalty * gradient_penalty\n\n    def _append_optional_cond(self, X: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:\n        if cond is None:\n            return X\n\n        return torch.cat([X, cond], dim=1)\n
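The discriminator is trained with a WGAN-GP style objective (critic loss -mean(D(real)) + mean(D(fake)) plus lambda_gradient_penalty times the gradient penalty), while the generator minimises -mean(D(G(z))). A hedged usage sketch, with transformed_data and metatransformer again standing in for the outputs of the dataloader module:

from nhssynth.modules.model.models.gan import GAN

model = GAN(
    transformed_data,              # placeholder: encoded training data
    metatransformer,               # placeholder: the associated MetaTransformer
    generator_n_layers_hidden=2,
    discriminator_n_layers_hidden=3,
    lambda_gradient_penalty=10,    # weight of the gradient penalty term
)
num_epochs, metrics = model.train(num_epochs=100, patience=5)   # tracks "GLoss" and "DLoss"
synthetic = model.generate(1000)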
"},{"location":"reference/modules/model/models/vae/","title":"vae","text":""},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.Decoder","title":"Decoder","text":"

Bases: Module

Decoder, takes in z and outputs reconstruction

Source code in src/nhssynth/modules/model/models/vae.py
class Decoder(nn.Module):\n    \"\"\"Decoder, takes in z and outputs reconstruction\"\"\"\n\n    def __init__(\n        self,\n        output_dim: int,\n        latent_dim: int,\n        hidden_dim: int,\n        activation: str,\n        learning_rate: float,\n        shared_optimizer: bool,\n    ) -> None:\n        super().__init__()\n        activation = ACTIVATION_FUNCTIONS[activation]\n        self.net = nn.Sequential(\n            nn.Linear(latent_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, output_dim),\n        )\n        if not shared_optimizer:\n            self.optim = torch.optim.Adam(self.parameters(), lr=learning_rate)\n\n    def forward(self, z):\n        return self.net(z)\n
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.Encoder","title":"Encoder","text":"

Bases: Module

Encoder, takes in x and outputs mu_z, sigma_z (diagonal Gaussian variational posterior assumed)

Source code in src/nhssynth/modules/model/models/vae.py
class Encoder(nn.Module):\n    \"\"\"Encoder, takes in x and outputs mu_z, sigma_z (diagonal Gaussian variational posterior assumed)\"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        latent_dim: int,\n        hidden_dim: int,\n        activation: str,\n        learning_rate: float,\n        shared_optimizer: bool,\n    ) -> None:\n        super().__init__()\n        activation = ACTIVATION_FUNCTIONS[activation]\n        self.latent_dim = latent_dim\n        self.net = nn.Sequential(\n            nn.Linear(input_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, hidden_dim),\n            activation(),\n            nn.Linear(hidden_dim, 2 * latent_dim),\n        )\n        if not shared_optimizer:\n            self.optim = torch.optim.Adam(self.parameters(), lr=learning_rate)\n\n    def forward(self, x):\n        outs = self.net(x)\n        mu_z = outs[:, : self.latent_dim]\n        logsigma_z = outs[:, self.latent_dim :]\n        return mu_z, logsigma_z\n
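The encoder outputs mu_z and logsigma_z; latent samples are then drawn with the reparameterisation trick, exactly as in VAE.loss below. A minimal sketch of that step with placeholder tensors:

import torch

mu_z = torch.zeros(4, 8)        # placeholder: encoder mean, batch of 4, latent dim 8
logsigma_z = torch.zeros(4, 8)  # placeholder: encoder log standard deviation

eps = torch.randn_like(mu_z)               # eps ~ N(0, I)
z = mu_z + eps * torch.exp(logsigma_z)     # z ~ N(mu_z, sigma_z^2), differentiable w.r.t. both encoder outputs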
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.VAE","title":"VAE","text":"

Bases: Model

A Variational Autoencoder (VAE) model. Accepts Model arguments as well as the following:

Parameters:

    encoder_latent_dim (int, default 256): The dimensionality of the latent space.
    encoder_hidden_dim (int, default 256): The dimensionality of the hidden layers in the encoder.
    encoder_activation (str, default 'leaky_relu'): The activation function to use in the encoder.
    encoder_learning_rate (float, default 0.001): The learning rate for the encoder.
    decoder_latent_dim (int, default 256): The dimensionality of the latent space fed into the decoder.
    decoder_hidden_dim (int, default 32): The dimensionality of the hidden layers in the decoder.
    decoder_activation (str, default 'leaky_relu'): The activation function to use in the decoder.
    decoder_learning_rate (float, default 0.001): The learning rate for the decoder.
    shared_optimizer (bool, default True): Whether to use a shared optimizer for the encoder and decoder.

Source code in src/nhssynth/modules/model/models/vae.py
class VAE(Model):\n    \"\"\"\n    A Variational Autoencoder (VAE) model. Accepts [`Model`][nhssynth.modules.model.common.model.Model] arguments as well as the following:\n\n    Args:\n        encoder_latent_dim: The dimensionality of the latent space.\n        encoder_hidden_dim: The dimensionality of the hidden layers in the encoder.\n        encoder_activation: The activation function to use in the encoder.\n        encoder_learning_rate: The learning rate for the encoder.\n        decoder_latent_dim: The dimensionality of the hidden layers in the decoder.\n        decoder_hidden_dim: The dimensionality of the hidden layers in the decoder.\n        decoder_activation: The activation function to use in the decoder.\n        decoder_learning_rate: The learning rate for the decoder.\n        shared_optimizer: Whether to use a shared optimizer for the encoder and decoder.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        encoder_latent_dim: int = 256,\n        encoder_hidden_dim: int = 256,\n        encoder_activation: str = \"leaky_relu\",\n        encoder_learning_rate: float = 1e-3,\n        decoder_latent_dim: int = 256,\n        decoder_hidden_dim: int = 32,\n        decoder_activation: str = \"leaky_relu\",\n        decoder_learning_rate: float = 1e-3,\n        shared_optimizer: bool = True,\n        **kwargs,\n    ) -> None:\n        super(VAE, self).__init__(*args, **kwargs)\n\n        self.shared_optimizer = shared_optimizer\n        self.encoder = Encoder(\n            input_dim=self.ncols,\n            latent_dim=encoder_latent_dim,\n            hidden_dim=encoder_hidden_dim,\n            activation=encoder_activation,\n            learning_rate=encoder_learning_rate,\n            shared_optimizer=self.shared_optimizer,\n        ).to(self.device)\n        self.decoder = Decoder(\n            output_dim=self.ncols,\n            latent_dim=decoder_latent_dim,\n            hidden_dim=decoder_hidden_dim,\n            activation=decoder_activation,\n            learning_rate=decoder_learning_rate,\n            shared_optimizer=self.shared_optimizer,\n        ).to(self.device)\n        self.noiser = Noiser(\n            len(self.single_column_indices),\n        ).to(self.device)\n        if self.shared_optimizer:\n            assert (\n                encoder_learning_rate == decoder_learning_rate\n            ), \"If `shared_optimizer` is True, `encoder_learning_rate` must equal `decoder_learning_rate`\"\n            self.optim = torch.optim.Adam(\n                list(self.encoder.parameters()) + list(self.decoder.parameters()),\n                lr=encoder_learning_rate,\n            )\n            self.zero_grad = self.optim.zero_grad\n            self.step = self.optim.step\n        else:\n            self.zero_grad = lambda: (self.encoder.optim.zero_grad(), self.decoder.optim.zero_grad())\n            self.step = lambda: (self.encoder.optim.step(), self.decoder.optim.step())\n\n    @classmethod\n    def get_args(cls) -> list[str]:\n        return [\n            \"encoder_latent_dim\",\n            \"encoder_hidden_dim\",\n            \"encoder_activation\",\n            \"encoder_learning_rate\",\n            \"decoder_latent_dim\",\n            \"decoder_hidden_dim\",\n            \"decoder_activation\",\n            \"decoder_learning_rate\",\n            \"shared_optimizer\",\n        ]\n\n    @classmethod\n    def get_metrics(cls) -> list[str]:\n        return [\n            \"ELBO\",\n            \"KLD\",\n            \"ReconstructionLoss\",\n           
 \"CategoricalLoss\",\n            \"NumericalLoss\",\n        ]\n\n    def reconstruct(self, X):\n        mu_z, logsigma_z = self.encoder(X)\n        x_recon = self.decoder(mu_z)\n        return x_recon\n\n    def generate(self, N: Optional[int] = None) -> pd.DataFrame:\n        N = N or self.nrows\n        z_samples = torch.randn_like(torch.ones((N, self.encoder.latent_dim)), device=self.device)\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n            x_gen = self.decoder(z_samples)\n        x_gen_ = torch.ones_like(x_gen, device=self.device)\n\n        if self.multi_column_indices != [[]]:\n            for cat_idxs in self.multi_column_indices:\n                x_gen_[:, cat_idxs] = torch.distributions.one_hot_categorical.OneHotCategorical(\n                    logits=x_gen[:, cat_idxs]\n                ).sample()\n\n        x_gen_[:, self.single_column_indices] = x_gen[:, self.single_column_indices] + torch.exp(\n            self.noiser(x_gen[:, self.single_column_indices])\n        ) * torch.randn_like(x_gen[:, self.single_column_indices])\n        if torch.cuda.is_available():\n            x_gen_ = x_gen_.cpu()\n        return self.metatransformer.inverse_apply(pd.DataFrame(x_gen_.detach(), columns=self.columns))\n\n    def loss(self, X):\n        mu_z, logsigma_z = self.encoder(X)\n\n        p = Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))\n        q = Normal(mu_z, torch.exp(logsigma_z))\n\n        kld = torch.sum(torch.distributions.kl_divergence(q, p))\n\n        s = torch.randn_like(mu_z)\n        z_samples = mu_z + s * torch.exp(logsigma_z)\n\n        x_recon = self.decoder(z_samples)\n\n        categoric_loglik = 0\n\n        if self.multi_column_indices != [[]]:\n            for cat_idxs in self.multi_column_indices:\n                categoric_loglik += -torch.nn.functional.cross_entropy(\n                    x_recon[:, cat_idxs],\n                    torch.max(X[:, cat_idxs], 1)[1],\n                ).sum()\n\n        gauss_loglik = 0\n        if self.single_column_indices:\n            gauss_loglik = (\n                Normal(\n                    loc=x_recon[:, self.single_column_indices],\n                    scale=torch.exp(self.noiser(x_recon[:, self.single_column_indices])),\n                )\n                .log_prob(X[:, self.single_column_indices])\n                .sum()\n            )\n\n        reconstruction_loss = -(categoric_loglik + gauss_loglik)\n\n        elbo = kld + reconstruction_loss\n\n        return {\n            \"ELBO\": elbo / X.size()[0],\n            \"ReconstructionLoss\": reconstruction_loss / X.size()[0],\n            \"KLD\": kld / X.size()[0],\n            \"CategoricalLoss\": categoric_loglik / X.size()[0],\n            \"NumericalLoss\": gauss_loglik / X.size()[0],\n        }\n\n    def train(\n        self,\n        num_epochs: int = 100,\n        patience: int = 5,\n        displayed_metrics: list[str] = [\"ELBO\"],\n    ) -> tuple[int, dict[str, list[float]]]:\n        \"\"\"\n        Train the model.\n\n        Args:\n            num_epochs: Number of epochs to train for.\n            patience: Number of epochs to wait for improvement before early stopping.\n            displayed_metrics: List of metrics to display during training.\n\n        Returns:\n            The number of epochs trained for and a dictionary of the tracked metrics.\n        \"\"\"\n        self._start_training(num_epochs, patience, displayed_metrics)\n\n        
self.encoder.train()\n        self.decoder.train()\n        self.noiser.train()\n\n        for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), leave=False):\n            for (Y_subset,) in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n                self.zero_grad()\n                with warnings.catch_warnings():\n                    warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n                    losses = self.loss(Y_subset.to(self.device))\n                losses[\"ELBO\"].backward()\n                self.step()\n                self._record_metrics(losses)\n\n            elbo = np.mean(self.metrics[\"ELBO\"][-len(self.data_loader) :])\n            if self._check_patience(epoch, elbo):\n                num_epochs = epoch + 1\n                break\n\n        self._finish_training(num_epochs)\n        return (num_epochs, self.metrics)\n
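The dictionary returned by VAE.loss (and hence the tracked metrics) follows a simple per-row decomposition, which can serve as a sanity check; a sketch assuming model is an initialised VAE and batch a float tensor of encoded rows (both names are placeholders):

import torch

# ELBO == KLD + ReconstructionLoss, and ReconstructionLoss == -(CategoricalLoss + NumericalLoss)
losses = model.loss(batch)
assert torch.isclose(losses["ELBO"], losses["KLD"] + losses["ReconstructionLoss"])
assert torch.isclose(losses["ReconstructionLoss"], -(losses["CategoricalLoss"] + losses["NumericalLoss"]))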
"},{"location":"reference/modules/model/models/vae/#nhssynth.modules.model.models.vae.VAE.train","title":"train(num_epochs=100, patience=5, displayed_metrics=['ELBO'])","text":"

Train the model.

Parameters:

    num_epochs (int, default 100): Number of epochs to train for.
    patience (int, default 5): Number of epochs to wait for improvement before early stopping.
    displayed_metrics (list[str], default ['ELBO']): List of metrics to display during training.

Returns:

    tuple[int, dict[str, list[float]]]: The number of epochs trained for and a dictionary of the tracked metrics.

Source code in src/nhssynth/modules/model/models/vae.py
def train(\n    self,\n    num_epochs: int = 100,\n    patience: int = 5,\n    displayed_metrics: list[str] = [\"ELBO\"],\n) -> tuple[int, dict[str, list[float]]]:\n    \"\"\"\n    Train the model.\n\n    Args:\n        num_epochs: Number of epochs to train for.\n        patience: Number of epochs to wait for improvement before early stopping.\n        displayed_metrics: List of metrics to display during training.\n\n    Returns:\n        The number of epochs trained for and a dictionary of the tracked metrics.\n    \"\"\"\n    self._start_training(num_epochs, patience, displayed_metrics)\n\n    self.encoder.train()\n    self.decoder.train()\n    self.noiser.train()\n\n    for epoch in tqdm(range(num_epochs), desc=\"Epochs\", position=len(self.stats_bars), leave=False):\n        for (Y_subset,) in tqdm(self.data_loader, desc=\"Batches\", position=len(self.stats_bars) + 1, leave=False):\n            self.zero_grad()\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"Using a non-full backward hook\")\n                losses = self.loss(Y_subset.to(self.device))\n            losses[\"ELBO\"].backward()\n            self.step()\n            self._record_metrics(losses)\n\n        elbo = np.mean(self.metrics[\"ELBO\"][-len(self.data_loader) :])\n        if self._check_patience(epoch, elbo):\n            num_epochs = epoch + 1\n            break\n\n    self._finish_training(num_epochs)\n    return (num_epochs, self.metrics)\n
"},{"location":"reference/modules/plotting/","title":"plotting","text":""},{"location":"reference/modules/plotting/io/","title":"io","text":""},{"location":"reference/modules/plotting/io/#nhssynth.modules.plotting.io.check_input_paths","title":"check_input_paths(fn_dataset, fn_typed, fn_evaluations, dir_experiment)","text":"

Sets up the input and output paths for the model files.

Parameters:

    fn_dataset (str, required): The base name of the dataset.
    fn_typed (str, required): The name of the typed data file.
    fn_evaluations (str, required): The name of the file containing the evaluation bundle.
    dir_experiment (Path, required): The path to the experiment directory.

Returns:

    tuple[str, str]: The (potentially adjusted) names of the dataset, typed data and evaluation bundle files.

Source code in src/nhssynth/modules/plotting/io.py
def check_input_paths(fn_dataset: str, fn_typed: str, fn_evaluations: str, dir_experiment: Path) -> tuple[str, str]:\n    \"\"\"\n    Sets up the input and output paths for the model files.\n\n    Args:\n        fn_dataset: The base name of the dataset.\n        fn_typed: The name of the typed data file.\n        fn_evaluations: The name of the file containing the evaluation bundle.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The paths to the data, metadata and metatransformer files.\n    \"\"\"\n    fn_dataset, fn_typed, fn_evaluations = io.consistent_endings([fn_dataset, fn_typed, fn_evaluations])\n    fn_typed, fn_evaluations = io.potential_suffixes([fn_typed, fn_evaluations], fn_dataset)\n    io.warn_if_path_supplied([fn_dataset, fn_typed, fn_evaluations], dir_experiment)\n    io.check_exists([fn_typed], dir_experiment)\n    return fn_dataset, fn_typed, fn_evaluations\n
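A hedged usage sketch with hypothetical file names; the io helpers normalise file endings and resolve dataset-specific names before the existence check runs against dir_experiment:

from pathlib import Path

from nhssynth.modules.plotting.io import check_input_paths

fn_dataset, fn_typed, fn_evaluations = check_input_paths(
    "support",                   # hypothetical dataset name
    "typed",                     # hypothetical typed-data file name
    "evaluations",               # hypothetical evaluation bundle file name
    Path("experiments/my-run"),  # hypothetical experiment directory
)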
"},{"location":"reference/modules/plotting/io/#nhssynth.modules.plotting.io.load_required_data","title":"load_required_data(args, dir_experiment)","text":"

Loads the data from args or from disk when the dataloader has not been run previously.

Parameters:

    args (Namespace, required): The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.
    dir_experiment (Path, required): The path to the experiment directory.

Returns:

    tuple[str, DataFrame, DataFrame, dict[str, dict[str, Any]]]: The dataset name, the typed real data and the evaluation bundle.

Source code in src/nhssynth/modules/plotting/io.py
def load_required_data(\n    args: argparse.Namespace, dir_experiment: Path\n) -> tuple[str, pd.DataFrame, pd.DataFrame, dict[str, dict[str, Any]]]:\n    \"\"\"\n    Loads the data from `args` or from disk when the dataloader has not be run previously.\n\n    Args:\n        args: The arguments passed to the module, in this case potentially carrying the outputs of the dataloader module.\n        dir_experiment: The path to the experiment directory.\n\n    Returns:\n        The data, metadata and metatransformer.\n    \"\"\"\n    if all(x in args.module_handover for x in [\"dataset\", \"typed\", \"evaluations\"]):\n        return (\n            args.module_handover[\"dataset\"],\n            args.module_handover[\"typed\"],\n            args.module_handover[\"evaluations\"],\n        )\n    else:\n        fn_dataset, fn_typed, fn_evaluations = check_input_paths(\n            args.dataset, args.typed, args.evaluations, dir_experiment\n        )\n\n        with open(dir_experiment / fn_typed, \"rb\") as f:\n            real_data = pickle.load(f)\n        with open(dir_experiment / fn_evaluations, \"rb\") as f:\n            evaluations = pickle.load(f)\n\n        return fn_dataset, real_data, evaluations\n
"},{"location":"reference/modules/plotting/plots/","title":"plots","text":""},{"location":"reference/modules/plotting/plots/#nhssynth.modules.plotting.plots.factorize_all_categoricals","title":"factorize_all_categoricals(df)","text":"

Factorize all categorical columns in a dataframe.

Source code in src/nhssynth/modules/plotting/plots.py
def factorize_all_categoricals(\n    df: pd.DataFrame,\n) -> pd.DataFrame:\n    \"\"\"Factorize all categorical columns in a dataframe.\"\"\"\n    for col in df.columns:\n        if df[col].dtype == \"object\":\n            df[col] = pd.factorize(df[col])[0]\n        elif df[col].dtype == \"datetime64[ns]\":\n            df[col] = pd.to_numeric(df[col])\n        min_val = df[col].min()\n        max_val = df[col].max()\n        df[col] = (df[col] - min_val) / (max_val - min_val)\n\n    return df\n
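A small worked example (hypothetical data): object columns are integer-coded, datetimes are converted to numeric, and every column is then min-max scaled to [0, 1]; note that the input frame is modified in place and also returned.

import pandas as pd

from nhssynth.modules.plotting.plots import factorize_all_categoricals

df = pd.DataFrame(
    {
        "sex": ["F", "M", "F", "M"],
        "admitted": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01", "2021-06-01"]),
        "age": [30, 50, 70, 90],
    }
)

scaled = factorize_all_categoricals(df)
# e.g. scaled["age"] is now [0.0, 1/3, 2/3, 1.0]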
"},{"location":"reference/modules/plotting/run/","title":"run","text":""},{"location":"reference/modules/structure/","title":"structure","text":""},{"location":"reference/modules/structure/run/","title":"run","text":""}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 12e3f836cc9c4a742b1aa4ad5e74556d8a853764..d07dee109de290162faaf4754a738ba4a2e0f996 100644 GIT binary patch delta 12 Tcmb=gXOr*d;1JWF$W{pe6zT&M delta 12 Tcmb=gXOr*d;3&|W$W{pe7N`Sl