diff --git a/.gitignore b/.gitignore index b6e4761..a477fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea +/.idea +.idea/ \ No newline at end of file diff --git a/docs/AIM-Strong_Sign_Norm-01_Colors.svg b/docs/strong_ai_in_industry_logo.svg similarity index 100% rename from docs/AIM-Strong_Sign_Norm-01_Colors.svg rename to docs/strong_ai_in_industry_logo.svg diff --git a/eXNN/bayes/api.py b/eXNN/bayes/api.py index d2910bb..fb2a053 100755 --- a/eXNN/bayes/api.py +++ b/eXNN/bayes/api.py @@ -1,6 +1,5 @@ from typing import Dict, Optional -import torch import torch.optim from eXNN.bayes.wrapper import create_dropout_bayesian_wrapper @@ -8,12 +7,12 @@ class DropoutBayesianWrapper: def __init__( - self, - model: torch.nn.Module, - mode: str, - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, + self, + model: torch.nn.Module, + mode: str, + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, ): """Class representing bayesian equivalent of a neural network. @@ -49,9 +48,9 @@ def predict(self, data, n_iter) -> Dict[str, torch.Tensor]: class GaussianBayesianWrapper: def __init__( - self, - model: torch.nn.Module, - sigma: float, + self, + model: torch.nn.Module, + sigma: float, ): """Class representing bayesian equivalent of a neural network. diff --git a/eXNN/bayes/wrapper.py b/eXNN/bayes/wrapper.py index 8e1ed43..456b0c6 100644 --- a/eXNN/bayes/wrapper.py +++ b/eXNN/bayes/wrapper.py @@ -3,18 +3,31 @@ import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional from torch.distributions import Beta class ModuleBayesianWrapper(nn.Module): + """ + A wrapper for neural network layers to apply Bayesian-style dropout or noise during training. + + Args: + layer (nn.Module): The layer to wrap (e.g., nn.Linear, nn.Conv2d). + p (Optional[float]): Dropout probability for simple dropout. + Mutually exclusive with `a`, `b`, and `sigma`. + a (Optional[float]): Alpha parameter for Beta distribution dropout. Used with `b`. + b (Optional[float]): Beta parameter for Beta distribution dropout. Used with `a`. + sigma (Optional[float]): Standard deviation for Gaussian noise. + Mutually exclusive with `p`, `a`, and `b`. + """ + def __init__( - self, - layer: nn.Module, - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, - sigma: Optional[float] = None, + self, + layer: nn.Module, + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, + sigma: Optional[float] = None, ): super(ModuleBayesianWrapper, self).__init__() @@ -36,6 +49,16 @@ def __init__( self.p, self.a, self.b, self.sigma = p, a, b, sigma def augment_weights(self, weights, bias): + """ + Apply the specified noise or dropout to the weights and bias. + + Args: + weights (torch.Tensor): The weights of the layer. + bias (torch.Tensor): The bias of the layer (can be None). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The augmented weights and bias. + """ # Check if dropout is chosen if (self.p is not None) or (self.a is not None and self.b is not None): @@ -45,10 +68,10 @@ def augment_weights(self, weights, bias): else: p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample() - weights = F.dropout(weights, p, training=True) + weights = functional.dropout(weights, p, training=True) if bias is not None: # In layers we sometimes have the ability to set bias to None - bias = F.dropout(bias, p, training=True) + bias = functional.dropout(bias, p, training=True) else: # If gauss is chosen, then apply it @@ -60,11 +83,20 @@ def augment_weights(self, weights, bias): return weights, bias def forward(self, x): + """ + Forward pass through the layer with augmented weights. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ weight, bias = self.augment_weights(self.layer.weight, self.layer.bias) if isinstance(self.layer, nn.Linear): - return F.linear(x, weight, bias) + return functional.linear(x, weight, bias) elif type(self.layer) in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: return self.layer._conv_forward(x, weight, bias) else: @@ -72,6 +104,18 @@ def forward(self, x): def replace_modules_with_wrapper(model, wrapper_module, params): + """ + Recursively replaces layers in a model with a Bayesian wrapper. + + Args: + model (nn.Module): The model containing layers to replace. + wrapper_module (type): The wrapper class. + params (dict): Parameters for the wrapper. + + Returns: + nn.Module: The model with wrapped layers. + """ + if type(model) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]: return wrapper_module(model, **params) @@ -88,12 +132,26 @@ def replace_modules_with_wrapper(model, wrapper_module, params): class NetworkBayes(nn.Module): + """ + Bayesian network with standard dropout. + + Args: + model (nn.Module): The base model. + dropout_p (float): Dropout probability. + """ + def __init__( - self, - model: nn.Module, - dropout_p: float, + self, + model: nn.Module, + dropout_p: float, ): + """ + Initialize the NetworkBayes with standard dropout. + Args: + model (nn.Module): The base model to wrap with Bayesian dropout. + dropout_p (float): Dropout probability for the Bayesian wrapper. + """ super(NetworkBayes, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -103,11 +161,21 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -125,13 +193,29 @@ def mean_forward( # calculate mean and std after applying bayesian with beta distribution class NetworkBayesBeta(nn.Module): + """ + Bayesian network with Beta distribution dropout. + + Args: + model (nn.Module): The base model. + alpha (float): Alpha parameter for the Beta distribution. + beta (float): Beta parameter for the Beta distribution. + """ + def __init__( - self, - model: torch.nn.Module, - alpha: float, - beta: float, + self, + model: torch.nn.Module, + alpha: float, + beta: float, ): - + """ + Initialize the NetworkBayesBeta with Beta distribution dropout. + + Args: + model (nn.Module): The base model to wrap with Bayesian Beta dropout. + alpha (float): Alpha parameter of the Beta distribution. + beta (float): Beta parameter of the Beta distribution. + """ super(NetworkBayesBeta, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -141,11 +225,21 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -163,12 +257,26 @@ def mean_forward( class NetworkBayesGauss(nn.Module): + """ + Bayesian network with Gaussian noise. + + Args: + model (nn.Module): The base model. + sigma (float): Standard deviation of the Gaussian noise. + """ + def __init__( - self, - model: torch.nn.Module, - sigma: float, + self, + model: torch.nn.Module, + sigma: float, ): + """ + Initialize the NetworkBayesGauss with Gaussian noise. + Args: + model (nn.Module): The base model to wrap with Bayesian Gaussian noise. + sigma (float): Standard deviation of the Gaussian noise to apply. + """ super(NetworkBayesGauss, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -178,11 +286,21 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -200,13 +318,28 @@ def mean_forward( def create_dropout_bayesian_wrapper( - model: torch.nn.Module, - mode: Optional[str] = "basic", - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, - sigma: Optional[float] = None, + model: torch.nn.Module, + mode: Optional[str] = "basic", + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, + sigma: Optional[float] = None, ) -> torch.nn.Module: + """ + Creates a Bayesian network with the specified dropout mode. + + Args: + model (nn.Module): The base model. + mode (str): The dropout mode ("basic", "beta", "gauss"). + p (Optional[float]): Dropout probability for "basic" mode. + a (Optional[float]): Alpha parameter for "beta" mode. + b (Optional[float]): Beta parameter for "beta" mode. + sigma (Optional[float]): Standard deviation for "gauss" mode. + + Returns: + nn.Module: The Bayesian network. + """ + if mode == "basic": net = NetworkBayes(model, p) @@ -216,4 +349,7 @@ def create_dropout_bayesian_wrapper( elif mode == 'gauss': net = NetworkBayesGauss(model, sigma) + else: + raise ValueError("Mode should be one of ('basic', 'beta', 'gauss').") + return net diff --git a/eXNN/topology/api.py b/eXNN/topology/api.py index 70e70cb..36d4433 100644 --- a/eXNN/topology/api.py +++ b/eXNN/topology/api.py @@ -10,21 +10,22 @@ def get_data_barcode( data: torch.Tensor, hom_type: str, - coefs_type: str, + coefficient_type: str, ) -> Dict[str, np.ndarray]: - """This function computes persistent homologies for a cloud of points as barcodes. + """ + Computes persistent homologies for a cloud of points as barcodes. Args: - data (torch.Tensor): input data of shape NxC1x...xCk, + data (torch.Tensor): Input data of shape NxC1x...xCk, where N is the number of data points, - C1,...,Ck are dimensions of each data point - hom_type (str): homotopy type - coefs_type (str): coefficients type + C1,...,Ck are dimensions of each data point. + hom_type (str): Homotopy type. + coefficient_type (str): Coefficients type. Returns: - Dict[str, np.ndarray]: barcode + Dict[str, np.ndarray]: Barcode. """ - return homologies.compute_data_barcode(data, hom_type, coefs_type) + return homologies.compute_data_barcode(data, hom_type, coefficient_type) def get_nn_barcodes( @@ -32,40 +33,39 @@ def get_nn_barcodes( data: torch.Tensor, layers: List[str], hom_type: str, - coefs_type: str, + coefficient_type: str, ) -> Dict[str, Dict[str, np.ndarray]]: """ - The function plots persistent homologies for latent representations - on different levels of the neural network as barcodes. + Computes persistent homologies for latent representations on different + levels of the neural network as barcodes. Args: - model (torch.nn.Module): neural network - data (torch.Tensor): input data of shape NxC1x...xCk, + model (torch.nn.Module): Neural network. + data (torch.Tensor): Input data of shape NxC1x...xCk, where N is the number of data points, - C1,...,Ck are dimensions of each data point - layers (List[str]): list of layers for visualization. Defaults to None. - If None, visualization for all layers is performed - hom_type (str): homotopy type - coefs_type (str): coefficients type + C1,...,Ck are dimensions of each data point. + layers (List[str]): List of layers for visualization. + hom_type (str): Homotopy type. + coefficient_type (str): Coefficients type. Returns: - Dict[str, Dict[str, np.ndarray]]: dictionary with a barcode for each layer + Dict[str, Dict[str, np.ndarray]]: Dictionary with a barcode for each layer. """ res = {} for layer in layers: - res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefs_type) + res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefficient_type) return res def plot_barcode(barcode: Dict[str, np.ndarray]) -> matplotlib.figure.Figure: """ - The function creates a plot of a persistent homologies barcode. + Creates a plot of a persistent homologies barcode. Args: - barcode (Dict[str, np.ndarray]): barcode + barcode (Dict[str, np.ndarray]): Barcode. Returns: - matplotlib.figure.Figure: a plot of the barcode + matplotlib.figure.Figure: Plot of the barcode. """ return homologies.plot_barcode(barcode) @@ -74,15 +74,15 @@ def evaluate_barcode( barcode: Dict[str, np.ndarray], metric_name: Optional[str] = None ) -> Union[float, Dict[str, float]]: """ - The function evaluates a persistent homologies barcode with a metric. + Evaluates a persistent homologies barcode with a metric. Args: - barcode (Dict[str, np.ndarray]): barcode - metric_name (Optional[str]): metric name - (if `None` all available metrics values are computed) + barcode (Dict[str, np.ndarray]): Barcode. + metric_name (Optional[str]): Metric name (if None, all available metrics + values are computed). Returns: - Union(float, Dict[str, float]): float if metric is specified - or a dictionary with value of each available metric + Union[float, Dict[str, float]]: Float if metric is specified, or a + dictionary with values of each available metric. """ return metrics.compute_metric(barcode, metric_name) diff --git a/eXNN/topology/homologies.py b/eXNN/topology/homologies.py index 6137ed7..19f470d 100644 --- a/eXNN/topology/homologies.py +++ b/eXNN/topology/homologies.py @@ -10,121 +10,148 @@ ) -def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str): +def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str) -> torch.Tensor: + """ + Extracts the activation output of a specified layer from a given model. + + Args: + model (torch.nn.Module): The neural network model. + x (torch.Tensor): The input tensor to the model. + layer (str): The name of the layer to extract activation from. + + Returns: + torch.Tensor: The activation output of the specified layer. + """ activation = {} def store_output(name): - def hook(model, input, output): + def hook(_, __, output): activation[name] = output.detach() return hook - h1 = getattr(model, layer).register_forward_hook(store_output(layer)) + hook_handle = getattr(model, layer).register_forward_hook(store_output(layer)) model.forward(x) - h1.remove() + hook_handle.remove() return activation[layer] -def _diagram_to_barcode(plot): +def _diagram_to_barcode(plot) -> Dict[str, np.ndarray]: + """ + Converts a persistence diagram plot into a barcode representation. + + Args: + plot: The plot object containing persistence diagram data. + + Returns: + Dict[str, np.ndarray]: A dictionary where keys are homology types, and + values are arrays of intervals. + """ data = plot["data"] homologies = {} - for h in data: - if h["name"] is None: + for homology in data: + if homology["name"] is None: continue - homologies[h["name"]] = list(zip(h["x"], h["y"])) + homologies[homology["name"]] = list(zip(homology["x"], homology["y"])) - for h in homologies.keys(): - homologies[h] = sorted(homologies[h], key=lambda x: x[0]) + for key in homologies.keys(): + homologies[key] = sorted(homologies[key], key=lambda x: x[0]) return homologies -def plot_barcode(barcode: Dict[str, np.ndarray]): +def plot_barcode(barcode: Dict[str, np.ndarray]) -> plt.Figure: + """ + Plots a barcode diagram from given intervals. + + Args: + barcode (Dict[str, np.ndarray]): A dictionary containing homology types and their intervals. + + Returns: + plt.Figure: The matplotlib figure containing the barcode diagram. + """ homologies = list(barcode.keys()) - nplots = len(homologies) - fig, ax = plt.subplots(nplots, figsize=(15, min(10, 5 * nplots))) + np_lots = len(homologies) + fig, axes = plt.subplots(np_lots, figsize=(15, min(10, 5 * np_lots))) - if nplots == 1: - ax = [ax] + if np_lots == 1: + axes = [axes] - for i in range(nplots): - name = homologies[i] - ax[i].set_title(name) - ax[i].set_ylim([-0.05, 1.05]) + for i, name in enumerate(homologies): + axes[i].set_title(name) + axes[i].set_ylim([-0.05, 1.05]) bars = barcode[name] n = len(bars) - for j in range(n): - bar = bars[j] - ax[i].plot([bar[0], bar[1]], [j / n, j / n], color="black") - labels = ["" for _ in range(len(ax[i].get_yticklabels()))] - ax[i].set_yticks(ax[i].get_yticks()) - ax[i].set_yticklabels(labels) - - if nplots == 1: - ax = ax[0] + for j, bar in enumerate(bars): + axes[i].plot([bar[0], bar[1]], [j / n, j / n], color="black") + axes[i].set_yticks([]) + plt.close(fig) return fig -def compute_data_barcode(data: torch.Tensor, hom_type: str, coefs_type: str): +def compute_data_barcode( + data: torch.Tensor, + hom_type: str, + coefficient_type: str +) -> Dict[str, np.ndarray]: + """ + Computes a barcode for the given data using persistent homology. + + Args: + data (torch.Tensor): The input data for barcode computation. + hom_type (str): The type of homology to use ("standard", "sparse", "weak"). + coefficient_type (str): The coefficient field for homology computation. + + Returns: + Dict[str, np.ndarray]: The computed barcode as a dictionary. + + Raises: + ValueError: If an invalid hom_type is provided. + """ if hom_type == "standard": - VR = VietorisRipsPersistence( + vr = VietorisRipsPersistence( homology_dimensions=[0], collapse_edges=True, - coeff=int(coefs_type), + coeff=int(coefficient_type), ) elif hom_type == "sparse": - VR = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefs_type)) + vr = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefficient_type)) elif hom_type == "weak": - VR = WeakAlphaPersistence( + vr = WeakAlphaPersistence( homology_dimensions=[0], collapse_edges=True, - coeff=int(coefs_type), + coeff=int(coefficient_type), ) else: - raise Exception('hom_type must be one of: "standard", "sparse", "weak"!') + raise ValueError('hom_type must be one of: "standard", "sparse", "weak"!') - if len(data.shape) > 2: + if data.ndim > 2: data = torch.nn.Flatten()(data) data = data.reshape(1, *data.shape) - diagrams = VR.fit_transform(data) - plot = VR.plot(diagrams) + diagrams = vr.fit_transform(data) + plot = vr.plot(diagrams) return _diagram_to_barcode(plot) def compute_nn_barcode( - model: torch.nn.Module, - x: torch.Tensor, - layer: str, - hom_type: str, - coefs_type: str, -): - act = _get_activation(model, x, layer) - return compute_data_barcode(act, hom_type, coefs_type) - - -# def get_homologies_experimental( -# model: torch.nn.Module, -# x: torch.Tensor, -# layer: str, -# dimensions: Optional[List[int]] = None, -# make_barplot: bool = True, -# rm_empty: bool = True, -# ): -# act = _get_activation(model, x, layer) -# act = act.reshape(1, *act.shape) -# # Dimensions must not be outside layer dimensionality -# N = act.shape[-1] -# dimensions = dimensions if dimensions is not None else [] -# dimensions = [i if i >= 0 else N + i for i in dimensions] -# dimensions = [i for i in dimensions if ((i >= 0) and (i < N))] -# dimensions = list(set(dimensions)) -# VR = VietorisRipsPersistence(homology_dimensions=dimensions, collapse_edges=True) -# diagrams = VR.fit_transform(act) -# plot = VR.plot(diagrams) -# if make_barplot: -# barcode = _diagram_to_barcode(plot) -# if rm_empty: -# barcode = {key: val for key, val in barcode.items() if len(val) > 0} -# return _plot_barcode(barcode) -# else: -# return plot + model: torch.nn.Module, + x: torch.Tensor, + layer: str, + hom_type: str, + coefficient_type: str, +) -> Dict[str, np.ndarray]: + """ + Computes a barcode for a specified layer in a neural network model. + + Args: + model (torch.nn.Module): The neural network model. + x (torch.Tensor): The input data for the model. + layer (str): The layer to extract activations from. + hom_type (str): The type of homology to use ("standard", "sparse", "weak"). + coefficient_type (str): The coefficient field for homology computation. + + Returns: + Dict[str, np.ndarray]: The computed barcode as a dictionary. + """ + activation = _get_activation(model, x, layer) + return compute_data_barcode(activation, hom_type, coefficient_type) diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py index 7b5f71a..6013812 100644 --- a/eXNN/topology/metrics.py +++ b/eXNN/topology/metrics.py @@ -1,25 +1,32 @@ import heapq +from typing import Dict import numpy as np def _get_available_metrics(): + """ + Returns a dictionary mapping metric names to their respective computation functions. + + Returns: + Dict[str, callable]: A dictionary of metric computation functions. + """ return { - # absolute length based metrics + # Absolute length-based metrics "max_length": _compute_longest_interval_metric, "mean_length": _compute_length_mean_metric, "median_length": _compute_length_median_metric, "stdev_length": _compute_length_stdev_metric, "sum_length": _compute_length_sum_metric, - # relative length based metrics + # Relative length-based metrics "ratio_2_1": _compute_two_to_one_ratio_metric, "ratio_3_1": _compute_three_to_one_ratio_metric, - # entopy based metrics + # Entropy-based metrics "h": _compute_entropy_metric, "normh": _compute_normed_entropy_metric, - # signal to noise ration + # Signal-to-noise ratio "snr": _compute_snr_metric, - # birth-death based metrics + # Birth-death based metrics "mean_birth": _compute_births_mean_metric, "stdev_birth": _compute_births_stdev_metric, "mean_death": _compute_deaths_mean_metric, @@ -27,120 +34,275 @@ def _get_available_metrics(): } -def compute_metric(barcode, metric_name=None): +def compute_metric(barcode: Dict[str, np.ndarray], metric_name: str = None): + """ + Computes specified or all metrics for a given barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to compute metrics for. + metric_name (str, optional): The specific metric name to compute. + If None, all metrics are computed. + + Returns: + float or Dict[str, float]: The computed metric(s). + """ metrics = _get_available_metrics() if metric_name is None: - return {name: fn(barcode) for (name, fn) in metrics.items()} - else: - return metrics[metric_name](barcode) + return {name: fn(barcode) for name, fn in metrics.items()} + return metrics[metric_name](barcode) + + +def _get_lengths(barcode: Dict[str, np.ndarray]): + """ + Extracts lengths of intervals from a barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _get_lengths(barcode): + Returns: + List[float]: A list of interval lengths. + """ diag = barcode["H0"] return [d[1] - d[0] for d in diag] -def _compute_longest_interval_metric(barcode): +def _compute_longest_interval_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the maximum interval length in the barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The maximum interval length. + """ lengths = _get_lengths(barcode) - return np.max(lengths).item() + return float(np.max(lengths)) + +def _compute_length_mean_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the mean interval length in the barcode. -def _compute_length_mean_metric(barcode): + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The mean interval length. + """ lengths = _get_lengths(barcode) - return np.mean(lengths).item() + return float(np.mean(lengths)) + +def _compute_length_median_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the median interval length in the barcode. -def _compute_length_median_metric(barcode): + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The median interval length. + """ lengths = _get_lengths(barcode) - return np.median(lengths).item() + return float(np.median(lengths)) + + +def _compute_length_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the standard deviation of interval lengths in the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _compute_length_stdev_metric(barcode): + Returns: + float: The standard deviation of interval lengths. + """ lengths = _get_lengths(barcode) - return np.std(lengths).item() + return float(np.std(lengths)) + + +def _compute_length_sum_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the sum of all interval lengths in the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _compute_length_sum_metric(barcode): + Returns: + float: The sum of all interval lengths. + """ lengths = _get_lengths(barcode) - return np.sum(lengths).item() + return float(np.sum(lengths)) -# Proportion between the longest intervals: 2/1 ratio, 3/1 ratio -def _compute_two_to_one_ratio_metric(barcode): +def _compute_two_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the ratio of the second largest to the largest interval length. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The ratio of the second largest to the largest interval length. + """ lengths = _get_lengths(barcode) value = heapq.nlargest(2, lengths)[1] / lengths[0] - return value.item() + return float(value) + +def _compute_three_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the ratio of the third largest to the largest interval length. -def _compute_three_to_one_ratio_metric(barcode): + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The ratio of the third largest to the largest interval length. + """ lengths = _get_lengths(barcode) value = heapq.nlargest(3, lengths)[2] / lengths[0] - return value.item() + return float(value) + +def _get_entropy(values: np.ndarray, normalize: bool) -> float: + """ + Computes the entropy of a given distribution. -# Compute the persistent entropy and normed persistent entropy -def _get_entropy(values, normalize: bool): + Args: + values (np.ndarray): The values to compute entropy for. + normalize (bool): Whether to normalize the entropy. + + Returns: + float: The computed entropy. + """ values_sum = np.sum(values) entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) if normalize: entropy = entropy / np.log(values_sum) - return entropy + return float(entropy) + + +def _compute_entropy_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the persistent entropy of intervals in the barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The persistent entropy. + """ + return _get_entropy(_get_lengths(barcode), normalize=False) -def _compute_entropy_metric(barcode): - return _get_entropy(_get_lengths(barcode), normalize=False).item() +def _compute_normed_entropy_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the normalized persistent entropy of intervals in the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _compute_normed_entropy_metric(barcode): - return _get_entropy(_get_lengths(barcode), normalize=True).item() + Returns: + float: The normalized persistent entropy. + """ + return _get_entropy(_get_lengths(barcode), normalize=True) -# Compute births -def _get_births(barcode): +def _get_births(barcode: Dict[str, np.ndarray]) -> np.ndarray: + """ + Extracts the birth times from the barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + np.ndarray: An array of birth times. + """ diag = barcode["H0"] return np.array([x[0] for x in diag]) -# Comput deaths -def _get_deaths(barcode): - diag = barcode["H0"] - return np.array([x[1] for x in diag]) +def _get_deaths(barcode: Dict[str, np.ndarray]) -> np.ndarray: + """ + Extracts the death times from the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -# def _get_birth(barcode, dim): -# diag = barcode['H0'] -# temp = np.array([x[0] for x in diag if x[2] == dim]) -# return temp[0] + Returns: + np.ndarray: An array of death times. + """ + diag = barcode["H0"] + return np.array([x[1] for x in diag]) -# def _get_death(barcode, dim): -# diag = barcode['H0'] -# temp = np.array([x[1] for x in diag if x[2] == dim]) -# return temp[-1] +def _compute_snr_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the signal-to-noise ratio (SNR) for the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -# Compute SNR -def _compute_snr_metric(barcode): + Returns: + float: The computed SNR. + """ births = _get_births(barcode) deaths = _get_deaths(barcode) signal = np.mean(deaths - births) noise = np.std(births) snr = signal / noise - return snr.item() + return float(snr) + + +def _compute_births_mean_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the mean of birth times in the barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The mean of birth times. + """ + return float(np.mean(_get_births(barcode))) + + +def _compute_births_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the standard deviation of birth times in the barcode. + + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. + + Returns: + float: The standard deviation of birth times. + """ + return float(np.std(_get_births(barcode))) -# Compute the birth-death pair indices: Birth mean, birth stdev, death mean, death stdev -def _compute_births_mean_metric(barcode): - return np.mean(_get_births(barcode)).item() +def _compute_deaths_mean_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the mean of death times in the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _compute_births_stdev_metric(barcode): - return np.std(_get_births(barcode)).item() + Returns: + float: The mean of death times. + """ + return float(np.mean(_get_deaths(barcode))) -def _compute_deaths_mean_metric(barcode): - return np.mean(_get_deaths(barcode)).item() +def _compute_deaths_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: + """ + Computes the standard deviation of death times in the barcode. + Args: + barcode (Dict[str, np.ndarray]): The barcode to process. -def _compute_deaths_stdev_metric(barcode): - return np.std(_get_deaths(barcode)).item() + Returns: + float: The standard deviation of death times. + """ + return float(np.std(_get_deaths(barcode))) diff --git a/tests/test_bayesian/__init__.py b/tests/test_bayesian/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py new file mode 100644 index 0000000..7861998 --- /dev/null +++ b/tests/test_bayesian/test_bayes.py @@ -0,0 +1,72 @@ +import torch + +import eXNN.bayes as bayes_api +import tests.utils.test_utils as utils + + +def _test_bayes_prediction(mode: str): + """ + Helper function to test Bayesian wrappers with different configurations. + + Args: + mode (str): Bayesian wrapper mode ("basic", "beta", "gauss"). + + Verifies: + - Output is a dictionary with keys "mean" and "std". + - Shapes of "mean" and "std" match the expected shape. + """ + params = { + "basic": dict(mode="basic", p=0.5), + "beta": dict(mode="beta", a=0.9, b=0.2), + "gauss": dict(sigma=1e-2), + } + + n, dim, data = utils.create_testing_data() + num_classes = 17 + model = utils.create_testing_model(num_classes=num_classes) + n_iter = 7 + + if mode != 'gauss': + res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + else: + res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, + n_iter=n_iter) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values({"mean", "std"}, set(res.keys()), "Wrong dictionary keys") + utils.compare_values(torch.Size([n, num_classes]), res["mean"].shape, + "Wrong mean shape") + utils.compare_values(torch.Size([n, num_classes]), res["std"].shape, + "Wrong mean std") + + +def test_basic_bayes_wrapper(): + """ + Test the basic Bayesian wrapper. + """ + _test_bayes_prediction("basic") + + +def test_beta_bayes_wrapper(): + """ + Test the beta Bayesian wrapper. + """ + _test_bayes_prediction("beta") + + +def test_gauss_bayes_wrapper(): + """ + Test the Gaussian Bayesian wrapper. + """ + _test_bayes_prediction("gauss") + + +def test_all_bayes_wrappers(): + """ + Test all Bayesian wrappers (basic, beta, and gauss) with a single general test. + """ + modes = ["basic", "beta", "gauss"] + + for mode in modes: + _test_bayes_prediction(mode) diff --git a/tests/test_topology/__init__.py b/tests/test_topology/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py new file mode 100644 index 0000000..22f5b57 --- /dev/null +++ b/tests/test_topology/test_barcode_general.py @@ -0,0 +1,108 @@ +import matplotlib + +import eXNN.topology as topology_api +import tests.utils.test_utils as utils + + +def test_data_barcode(): + """ + Test generating a barcode from data. + + Verifies: + - Output is a dictionary. + """ + n, dim, data = utils.create_testing_data() + res = topology_api.get_data_barcode(data, "standard", "3") + utils.compare_values(dict, type(res), "Wrong result type") + + +def test_nn_barcodes(): + """ + Test generating barcodes for a neural network. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the specified layers. + - Each dictionary value is a dictionary. + """ + n, dim, data = utils.create_testing_data() + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + + res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3") + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys") + + for layer, barcode in res.items(): + utils.compare_values( + dict, + type(barcode), + f"Wrong result type for key {layer}", + ) + + +def test_barcode_plot(): + """ + Test generating a barcode plot. + + Verifies: + - Output is a matplotlib.figure.Figure. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + plot = topology_api.plot_barcode(barcode) + utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type") + + +def test_all_barcodes(): + """ + Test all barcode-related functions (data barcode, NN barcodes, and barcode plot) together. + + Verifies: + - Data barcode is generated correctly. + - Neural network barcodes are generated correctly for specified layers. + - Barcode plot is generated correctly. + """ + # Test data barcode + n, dim, data = utils.create_testing_data() + res = topology_api.get_data_barcode(data, "standard", "3") + utils.compare_values(dict, type(res), "Wrong result type for data barcode") + + # Test NN barcodes + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + nn_barcodes = topology_api.get_nn_barcodes( + model, + data, + layers, + "standard", + "3" + ) + + utils.compare_values( + dict, + type(nn_barcodes), + "Wrong result type for NN barcodes" + ) + + utils.compare_values( + 2, + len(nn_barcodes), + "Wrong dictionary length for NN barcodes" + ) + + utils.compare_values( + set(layers), + set(nn_barcodes.keys()), + "Wrong dictionary keys for NN barcodes" + ) + + # Test barcode plot + barcode_plot = topology_api.plot_barcode(res) + utils.compare_values( + matplotlib.figure.Figure, + type(barcode_plot), + "Wrong result type for barcode plot" + ) diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py new file mode 100644 index 0000000..0a77073 --- /dev/null +++ b/tests/test_topology/test_barcode_metrics.py @@ -0,0 +1,98 @@ +import eXNN.topology as topology_api +import tests.utils.test_utils as utils + + +def test_barcode_evaluate_all_metrics(): + """ + Test evaluating all metrics for a barcode. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the expected metric names. + - Each metric value is a float. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode) + utils.compare_values(dict, type(result), "Wrong result type") + all_metric_names = [ + "h", + "max_length", + "mean_birth", + "mean_death", + "mean_length", + "median_length", + "normh", + "ratio_2_1", + "ratio_3_1", + "snr", + "stdev_birth", + "stdev_death", + "stdev_length", + "sum_length", + ] + utils.compare_values(all_metric_names, sorted(result.keys())) + for name, value in result.items(): + utils.compare_values(float, type(value), f"Wrong result type for metric {name}") + + +def test_barcode_evaluate_one_metric(): + """ + Test evaluating a single metric for a barcode. + + Verifies: + - Output is a float. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode, metric_name="mean_length") + utils.compare_values(float, type(result), "Wrong result type") + + +def test_barcode_evaluate_all_metrics_and_individual(): + """ + Test evaluating all metrics for a barcode and individual metrics evaluation. + + Verifies: + - Output for evaluating all metrics is a dictionary. + - Dictionary keys match the expected metric names. + - Each metric value is a float. + - Individual metrics can be correctly evaluated. + """ + # Test for all metrics + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode) + + # Check that the result is a dictionary + utils.compare_values(dict, type(result), "Wrong result type") + + # List of expected metric names + all_metric_names = [ + "h", "max_length", "mean_birth", "mean_death", "mean_length", + "median_length", "normh", "ratio_2_1", "ratio_3_1", "snr", + "stdev_birth", "stdev_death", "stdev_length", "sum_length" + ] + + # Check that the dictionary keys match the expected metric names + utils.compare_values( + all_metric_names, + sorted(result.keys()), + "Wrong dictionary keys" + ) + + # Ensure all metric values are floats + for name, value in result.items(): + utils.compare_values( + float, type(value), + f"Wrong result type for metric {name}" + ) + + # Test for evaluating individual metrics + for metric_name in all_metric_names: + individual_result = topology_api.evaluate_barcode(barcode, metric_name=metric_name) + utils.compare_values( + float, + type(individual_result), + f"Wrong result type for individual metric {metric_name}" + ) diff --git a/tests/test_visualization/__init__.py b/tests/test_visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_visualization/test_check_random_input.py b/tests/test_visualization/test_check_random_input.py new file mode 100644 index 0000000..5bd6360 --- /dev/null +++ b/tests/test_visualization/test_check_random_input.py @@ -0,0 +1,18 @@ +import torch + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def test_check_random_input(): + """ + Test generating random input tensors with a specific shape. + + Verifies: + - Output type is a torch.Tensor. + - Output shape matches the specified shape. + """ + shape = [5, 17, 81, 37] + data = viz_api.get_random_input(shape) + utils.compare_values(type(data), torch.Tensor, "Wrong result type") + utils.compare_values(torch.Size(shape), data.shape, "Wrong result shape") diff --git a/tests/test_visualization/test_reduce_dim.py b/tests/test_visualization/test_reduce_dim.py new file mode 100644 index 0000000..8ed32b9 --- /dev/null +++ b/tests/test_visualization/test_reduce_dim.py @@ -0,0 +1,55 @@ +import numpy as np + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def _check_reduce_dim(mode): + """ + Helper function to test dimensionality reduction with a given mode. + + Args: + mode (str): Dimensionality reduction mode (e.g., "umap", "pca"). + + Verifies: + - Output type is a numpy.ndarray. + - Output shape is (n_samples, 2). + """ + n, dim, data = utils.create_testing_data() + reduced_data = viz_api.reduce_dim(data, mode) + utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type") + utils.compare_values((n, 2), reduced_data.shape, "Wrong result shape") + + +def test_reduce_dim_umap(): + """ + Test dimensionality reduction using UMAP. + + Uses: + _check_reduce_dim with mode "umap". + """ + _check_reduce_dim("umap") + + +def test_reduce_dim_pca(): + """ + Test dimensionality reduction using PCA. + + Uses: + _check_reduce_dim with mode "pca". + """ + _check_reduce_dim("pca") + + +def test_all_reduce_dim_methods(): + """ + Test dimensionality reduction using all available methods (e.g., UMAP, PCA). + + Verifies: + - Dimensionality reduction works for each method. + - Output is of type numpy.ndarray and shape (n_samples, 2). + """ + modes = ["umap", "pca"] + + for mode in modes: + _check_reduce_dim(mode) diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py new file mode 100644 index 0000000..859f0b3 --- /dev/null +++ b/tests/test_visualization/test_visualization.py @@ -0,0 +1,69 @@ +import matplotlib +import plotly +import torch + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def test_visualization(): + """ + Test visualization of layer manifolds using UMAP. + + Verifies: + - Output is a dictionary. + - Dictionary has the correct keys corresponding to the input and specified layers. + - Each dictionary value is a matplotlib.figure.Figure. + """ + + n, dim, data = utils.create_testing_data() + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(3, len(res), "Wrong dictionary length") + utils.compare_values( + set(["input"] + layers), + set(res.keys()), + "Wrong dictionary keys", + ) + + for key, plot in res.items(): + utils.compare_values( + matplotlib.figure.Figure, + type(plot), + f"Wrong value type for key {key}", + ) + + +def test_embed_visualization(): + """ + Test visualization of recurrent layer manifolds with embeddings. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the specified layers. + - Each dictionary value is a plotly.graph_objs.Figure. + """ + data = torch.randn((20, 1, 256)) + labels = torch.randn(20) + model = utils.create_testing_model_lstm() + layers = ["second_layer", "third_layer"] + + res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, + labels=labels) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values( + set(layers), + set(res.keys()), + "Wrong dictionary keys", + ) + for key, plot in res.items(): + utils.compare_values( + plotly.graph_objs.Figure, + type(plot), + f"Wrong value type for key {key}", + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py new file mode 100644 index 0000000..ef0f82c --- /dev/null +++ b/tests/utils/test_utils.py @@ -0,0 +1,123 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + + +def _form_message_header(message_header=None): + """ + Forms the header for assertion messages. + + Args: + message_header (str, optional): Custom message header. Defaults to None. + + Returns: + str: The custom header if provided, otherwise "Value mismatch". + """ + return "Value mismatch" if message_header is None else message_header + + +def compare_values(expected, got, message_header=None): + """ + Compares two values and raises an assertion error if they do not match. + + Args: + expected: The expected value. + got: The value to compare against the expected value. + message_header (str, optional): Custom header for the assertion message. Defaults to None. + + Raises: + AssertionError: If the values do not match, an error with the message header is raised. + """ + assert (expected == got), \ + f"{_form_message_header(message_header)}: expected {expected}, got {got}" + + +def create_testing_data(): + """ + Creates synthetic testing data. + + Returns: + tuple: A tuple containing: + - N (int): Number of samples. + - dim (int): Dimensionality of each sample. + - data (torch.Tensor): Randomly generated data tensor of shape (N, dim). + """ + n = 20 + dim = 256 + data = torch.randn((n, dim)) + return n, dim, data + + +def create_testing_model(num_classes=10): + """ + Creates a simple feedforward neural network for testing. + + Args: + num_classes (int, optional): Number of output classes. Defaults to 10. + + Returns: + nn.Sequential: A sequential model with three layers: + - Linear layer (input_dim=256, output_dim=128). + - Linear layer (input_dim=128, output_dim=64). + - Linear layer (input_dim=64, output_dim=num_classes). + """ + return nn.Sequential( + OrderedDict( + [ + ("first_layer", nn.Linear(256, 128)), + ("second_layer", nn.Linear(128, 64)), + ("third_layer", nn.Linear(64, num_classes)), + ], + ), + ) + + +class ExtractTensor(nn.Module): + """ + A custom PyTorch module to extract and process tensors. + + This module extracts the first tensor from a tuple, converts it to + float32, and returns its values. + """ + + @staticmethod + def forward(x): + """ + Forward pass for tensor extraction and processing. + + Args: + x (tuple): Input tuple where the first element is the tensor to be processed. + + Returns: + torch.Tensor: Processed tensor (converted to float32). + """ + tensor, _ = x + tensor = tensor.to(torch.float32) + return tensor[:, :] + + +def create_testing_model_lstm(num_classes=10): + """ + Creates a recurrent neural network with LSTM layers for testing. + + Args: + num_classes (int, optional): Number of output classes. Defaults to 10. + + Returns: + nn.Sequential: A sequential model with the following layers: + - LSTM layer (input_dim=256, hidden_dim=128, num_layers=1). + - ExtractTensor layer to process LSTM output. + - Linear layer (input_dim=128, output_dim=64). + - Linear layer (input_dim=64, output_dim=num_classes). + """ + return nn.Sequential( + OrderedDict( + [ + ('first_layer', nn.LSTM(256, 128, 1, batch_first=True)), + ('extract', ExtractTensor()), + ('second_layer', nn.Linear(128, 64)), + ('third_layer', nn.Linear(64, num_classes)), + ], + ), + )