From e4f76a331bea5c6e1174bcca1c762d7e8a9ac0fc Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 02:18:03 +0300 Subject: [PATCH 01/16] feat: add doc strings for functions --- eXNN/topology/metrics.py | 224 +++++++++++++++++++++++++++++++++------ 1 file changed, 191 insertions(+), 33 deletions(-) diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py index 7b5f71a..d5b5111 100644 --- a/eXNN/topology/metrics.py +++ b/eXNN/topology/metrics.py @@ -1,25 +1,30 @@ import heapq - import numpy as np def _get_available_metrics(): + """ + Returns a dictionary of available metric names and their corresponding functions. + + Returns: + dict: Mapping of metric names to functions. + """ return { - # absolute length based metrics + # Absolute length-based metrics "max_length": _compute_longest_interval_metric, "mean_length": _compute_length_mean_metric, "median_length": _compute_length_median_metric, "stdev_length": _compute_length_stdev_metric, "sum_length": _compute_length_sum_metric, - # relative length based metrics + # Relative length-based metrics "ratio_2_1": _compute_two_to_one_ratio_metric, "ratio_3_1": _compute_three_to_one_ratio_metric, - # entopy based metrics + # Entropy-based metrics "h": _compute_entropy_metric, "normh": _compute_normed_entropy_metric, - # signal to noise ration + # Signal-to-noise ratio "snr": _compute_snr_metric, - # birth-death based metrics + # Birth-death based metrics "mean_birth": _compute_births_mean_metric, "stdev_birth": _compute_births_stdev_metric, "mean_death": _compute_deaths_mean_metric, @@ -28,119 +33,272 @@ def _get_available_metrics(): def compute_metric(barcode, metric_name=None): + """ + Compute the specified metric or all available metrics for a given barcode. + + Args: + barcode (dict): The barcode data containing persistent homology intervals. + metric_name (str, optional): The name of the metric to compute. Defaults to None. + + Returns: + dict or float: A dictionary of all metrics if metric_name is None, otherwise the value of the specified metric. + """ metrics = _get_available_metrics() if metric_name is None: - return {name: fn(barcode) for (name, fn) in metrics.items()} - else: - return metrics[metric_name](barcode) + return {name: fn(barcode) for name, fn in metrics.items()} + return metrics[metric_name](barcode) def _get_lengths(barcode): + """ + Compute lengths of intervals in the H0 component of the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + list: A list of interval lengths. + """ diag = barcode["H0"] return [d[1] - d[0] for d in diag] def _compute_longest_interval_metric(barcode): + """ + Compute the length of the longest interval in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The length of the longest interval. + """ lengths = _get_lengths(barcode) return np.max(lengths).item() def _compute_length_mean_metric(barcode): + """ + Compute the mean length of intervals in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The mean length of intervals. + """ lengths = _get_lengths(barcode) return np.mean(lengths).item() def _compute_length_median_metric(barcode): + """ + Compute the median length of intervals in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The median length of intervals. + """ lengths = _get_lengths(barcode) return np.median(lengths).item() def _compute_length_stdev_metric(barcode): + """ + Compute the standard deviation of interval lengths in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The standard deviation of interval lengths. + """ lengths = _get_lengths(barcode) return np.std(lengths).item() def _compute_length_sum_metric(barcode): + """ + Compute the sum of interval lengths in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The sum of interval lengths. + """ lengths = _get_lengths(barcode) return np.sum(lengths).item() -# Proportion between the longest intervals: 2/1 ratio, 3/1 ratio def _compute_two_to_one_ratio_metric(barcode): + """ + Compute the ratio of the second longest to the longest interval in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The 2-to-1 ratio of interval lengths. + """ lengths = _get_lengths(barcode) value = heapq.nlargest(2, lengths)[1] / lengths[0] - return value.item() + return value def _compute_three_to_one_ratio_metric(barcode): + """ + Compute the ratio of the third longest to the longest interval in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The 3-to-1 ratio of interval lengths. + """ lengths = _get_lengths(barcode) value = heapq.nlargest(3, lengths)[2] / lengths[0] - return value.item() + return value + +def _get_entropy(values, normalize): + """ + Compute the entropy of a set of values. -# Compute the persistent entropy and normed persistent entropy -def _get_entropy(values, normalize: bool): + Args: + values (list): The values for which to compute entropy. + normalize (bool): Whether to normalize the entropy. + + Returns: + float: The computed entropy. + """ values_sum = np.sum(values) - entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) + entropy = -np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) if normalize: - entropy = entropy / np.log(values_sum) + entropy /= np.log(values_sum) return entropy def _compute_entropy_metric(barcode): + """ + Compute the persistent entropy of the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The persistent entropy. + """ return _get_entropy(_get_lengths(barcode), normalize=False).item() def _compute_normed_entropy_metric(barcode): + """ + Compute the normalized persistent entropy of the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The normalized persistent entropy. + """ return _get_entropy(_get_lengths(barcode), normalize=True).item() -# Compute births def _get_births(barcode): + """ + Extract the birth times from the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + np.ndarray: The birth times. + """ diag = barcode["H0"] return np.array([x[0] for x in diag]) -# Comput deaths def _get_deaths(barcode): - diag = barcode["H0"] - return np.array([x[1] for x in diag]) + """ + Extract the death times from the barcode. + Args: + barcode (dict): The barcode data. -# def _get_birth(barcode, dim): -# diag = barcode['H0'] -# temp = np.array([x[0] for x in diag if x[2] == dim]) -# return temp[0] + Returns: + np.ndarray: The death times. + """ + diag = barcode["H0"] + return np.array([x[1] for x in diag]) -# def _get_death(barcode, dim): -# diag = barcode['H0'] -# temp = np.array([x[1] for x in diag if x[2] == dim]) -# return temp[-1] +def _compute_snr_metric(barcode): + """ + Compute the signal-to-noise ratio (SNR) of the barcode. + Args: + barcode (dict): The barcode data. -# Compute SNR -def _compute_snr_metric(barcode): + Returns: + float: The SNR value. + """ births = _get_births(barcode) deaths = _get_deaths(barcode) signal = np.mean(deaths - births) noise = np.std(births) - snr = signal / noise - return snr.item() + return (signal / noise).item() -# Compute the birth-death pair indices: Birth mean, birth stdev, death mean, death stdev def _compute_births_mean_metric(barcode): + """ + Compute the mean of birth times in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The mean birth time. + """ return np.mean(_get_births(barcode)).item() def _compute_births_stdev_metric(barcode): + """ + Compute the standard deviation of birth times in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The standard deviation of birth times. + """ return np.std(_get_births(barcode)).item() def _compute_deaths_mean_metric(barcode): + """ + Compute the mean of death times in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The mean death time. + """ return np.mean(_get_deaths(barcode)).item() def _compute_deaths_stdev_metric(barcode): + """ + Compute the standard deviation of death times in the barcode. + + Args: + barcode (dict): The barcode data. + + Returns: + float: The standard deviation of death times. + """ return np.std(_get_deaths(barcode)).item() From 54192d1079c33af1676485a5af59f8754e0dd729 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 02:26:41 +0300 Subject: [PATCH 02/16] feat: add doc strings, breaklines; rename some variables --- eXNN/topology/homologies.py | 170 ++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 74 deletions(-) diff --git a/eXNN/topology/homologies.py b/eXNN/topology/homologies.py index 6137ed7..e5a0196 100644 --- a/eXNN/topology/homologies.py +++ b/eXNN/topology/homologies.py @@ -10,121 +10,143 @@ ) -def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str): +def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str) -> torch.Tensor: + """ + Extracts the activation output of a specified layer from a given model. + + Args: + model (torch.nn.Module): The neural network model. + x (torch.Tensor): The input tensor to the model. + layer (str): The name of the layer to extract activation from. + + Returns: + torch.Tensor: The activation output of the specified layer. + """ activation = {} def store_output(name): - def hook(model, input, output): + def hook(_, __, output): activation[name] = output.detach() return hook - h1 = getattr(model, layer).register_forward_hook(store_output(layer)) + hook_handle = getattr(model, layer).register_forward_hook(store_output(layer)) model.forward(x) - h1.remove() + hook_handle.remove() return activation[layer] -def _diagram_to_barcode(plot): +def _diagram_to_barcode(plot) -> Dict[str, np.ndarray]: + """ + Converts a persistence diagram plot into a barcode representation. + + Args: + plot: The plot object containing persistence diagram data. + + Returns: + Dict[str, np.ndarray]: A dictionary where keys are homology types, and values are arrays of intervals. + """ data = plot["data"] homologies = {} - for h in data: - if h["name"] is None: + for homology in data: + if homology["name"] is None: continue - homologies[h["name"]] = list(zip(h["x"], h["y"])) + homologies[homology["name"]] = list(zip(homology["x"], homology["y"])) - for h in homologies.keys(): - homologies[h] = sorted(homologies[h], key=lambda x: x[0]) + for key in homologies.keys(): + homologies[key] = sorted(homologies[key], key=lambda x: x[0]) return homologies -def plot_barcode(barcode: Dict[str, np.ndarray]): +def plot_barcode(barcode: Dict[str, np.ndarray]) -> plt.Figure: + """ + Plots a barcode diagram from given intervals. + + Args: + barcode (Dict[str, np.ndarray]): A dictionary containing homology types and their intervals. + + Returns: + plt.Figure: The matplotlib figure containing the barcode diagram. + """ homologies = list(barcode.keys()) - nplots = len(homologies) - fig, ax = plt.subplots(nplots, figsize=(15, min(10, 5 * nplots))) + np_lots = len(homologies) + fig, axes = plt.subplots(np_lots, figsize=(15, min(10, 5 * np_lots))) - if nplots == 1: - ax = [ax] + if np_lots == 1: + axes = [axes] - for i in range(nplots): - name = homologies[i] - ax[i].set_title(name) - ax[i].set_ylim([-0.05, 1.05]) + for i, name in enumerate(homologies): + axes[i].set_title(name) + axes[i].set_ylim([-0.05, 1.05]) bars = barcode[name] n = len(bars) - for j in range(n): - bar = bars[j] - ax[i].plot([bar[0], bar[1]], [j / n, j / n], color="black") - labels = ["" for _ in range(len(ax[i].get_yticklabels()))] - ax[i].set_yticks(ax[i].get_yticks()) - ax[i].set_yticklabels(labels) - - if nplots == 1: - ax = ax[0] + for j, bar in enumerate(bars): + axes[i].plot([bar[0], bar[1]], [j / n, j / n], color="black") + axes[i].set_yticks([]) + plt.close(fig) return fig -def compute_data_barcode(data: torch.Tensor, hom_type: str, coefs_type: str): +def compute_data_barcode(data: torch.Tensor, hom_type: str, coefficient_type: str) -> Dict[str, np.ndarray]: + """ + Computes a barcode for the given data using persistent homology. + + Args: + data (torch.Tensor): The input data for barcode computation. + hom_type (str): The type of homology to use ("standard", "sparse", "weak"). + coefficient_type (str): The coefficient field for homology computation. + + Returns: + Dict[str, np.ndarray]: The computed barcode as a dictionary. + + Raises: + ValueError: If an invalid hom_type is provided. + """ if hom_type == "standard": - VR = VietorisRipsPersistence( + vr = VietorisRipsPersistence( homology_dimensions=[0], collapse_edges=True, - coeff=int(coefs_type), + coeff=int(coefficient_type), ) elif hom_type == "sparse": - VR = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefs_type)) + vr = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefficient_type)) elif hom_type == "weak": - VR = WeakAlphaPersistence( + vr = WeakAlphaPersistence( homology_dimensions=[0], collapse_edges=True, - coeff=int(coefs_type), + coeff=int(coefficient_type), ) else: - raise Exception('hom_type must be one of: "standard", "sparse", "weak"!') + raise ValueError('hom_type must be one of: "standard", "sparse", "weak"!') - if len(data.shape) > 2: + if data.ndim > 2: data = torch.nn.Flatten()(data) data = data.reshape(1, *data.shape) - diagrams = VR.fit_transform(data) - plot = VR.plot(diagrams) + diagrams = vr.fit_transform(data) + plot = vr.plot(diagrams) return _diagram_to_barcode(plot) def compute_nn_barcode( - model: torch.nn.Module, - x: torch.Tensor, - layer: str, - hom_type: str, - coefs_type: str, -): - act = _get_activation(model, x, layer) - return compute_data_barcode(act, hom_type, coefs_type) - - -# def get_homologies_experimental( -# model: torch.nn.Module, -# x: torch.Tensor, -# layer: str, -# dimensions: Optional[List[int]] = None, -# make_barplot: bool = True, -# rm_empty: bool = True, -# ): -# act = _get_activation(model, x, layer) -# act = act.reshape(1, *act.shape) -# # Dimensions must not be outside layer dimensionality -# N = act.shape[-1] -# dimensions = dimensions if dimensions is not None else [] -# dimensions = [i if i >= 0 else N + i for i in dimensions] -# dimensions = [i for i in dimensions if ((i >= 0) and (i < N))] -# dimensions = list(set(dimensions)) -# VR = VietorisRipsPersistence(homology_dimensions=dimensions, collapse_edges=True) -# diagrams = VR.fit_transform(act) -# plot = VR.plot(diagrams) -# if make_barplot: -# barcode = _diagram_to_barcode(plot) -# if rm_empty: -# barcode = {key: val for key, val in barcode.items() if len(val) > 0} -# return _plot_barcode(barcode) -# else: -# return plot + model: torch.nn.Module, + x: torch.Tensor, + layer: str, + hom_type: str, + coefficient_type: str, +) -> Dict[str, np.ndarray]: + """ + Computes a barcode for a specified layer in a neural network model. + + Args: + model (torch.nn.Module): The neural network model. + x (torch.Tensor): The input data for the model. + layer (str): The layer to extract activations from. + hom_type (str): The type of homology to use ("standard", "sparse", "weak"). + coefficient_type (str): The coefficient field for homology computation. + + Returns: + Dict[str, np.ndarray]: The computed barcode as a dictionary. + """ + activation = _get_activation(model, x, layer) + return compute_data_barcode(activation, hom_type, coefficient_type) From d8c8fe4a6cf8669d08502617f9510ccd838da1f2 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 02:32:52 +0300 Subject: [PATCH 03/16] fix: change return type to float --- eXNN/topology/metrics.py | 183 ++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 90 deletions(-) diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py index d5b5111..85db57b 100644 --- a/eXNN/topology/metrics.py +++ b/eXNN/topology/metrics.py @@ -1,13 +1,15 @@ import heapq +from typing import Dict + import numpy as np def _get_available_metrics(): """ - Returns a dictionary of available metric names and their corresponding functions. + Returns a dictionary mapping metric names to their respective computation functions. Returns: - dict: Mapping of metric names to functions. + Dict[str, callable]: A dictionary of metric computation functions. """ return { # Absolute length-based metrics @@ -32,16 +34,16 @@ def _get_available_metrics(): } -def compute_metric(barcode, metric_name=None): +def compute_metric(barcode: Dict[str, np.ndarray], metric_name: str = None): """ - Compute the specified metric or all available metrics for a given barcode. + Computes specified or all metrics for a given barcode. Args: - barcode (dict): The barcode data containing persistent homology intervals. - metric_name (str, optional): The name of the metric to compute. Defaults to None. + barcode (Dict[str, np.ndarray]): The barcode to compute metrics for. + metric_name (str, optional): The specific metric name to compute. If None, all metrics are computed. Returns: - dict or float: A dictionary of all metrics if metric_name is None, otherwise the value of the specified metric. + float or Dict[str, float]: The computed metric(s). """ metrics = _get_available_metrics() if metric_name is None: @@ -49,256 +51,257 @@ def compute_metric(barcode, metric_name=None): return metrics[metric_name](barcode) -def _get_lengths(barcode): +def _get_lengths(barcode: Dict[str, np.ndarray]): """ - Compute lengths of intervals in the H0 component of the barcode. + Extracts lengths of intervals from a barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - list: A list of interval lengths. + List[float]: A list of interval lengths. """ diag = barcode["H0"] return [d[1] - d[0] for d in diag] -def _compute_longest_interval_metric(barcode): +def _compute_longest_interval_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the length of the longest interval in the barcode. + Computes the maximum interval length in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The length of the longest interval. + float: The maximum interval length. """ lengths = _get_lengths(barcode) - return np.max(lengths).item() + return float(np.max(lengths)) -def _compute_length_mean_metric(barcode): +def _compute_length_mean_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the mean length of intervals in the barcode. + Computes the mean interval length in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The mean length of intervals. + float: The mean interval length. """ lengths = _get_lengths(barcode) - return np.mean(lengths).item() + return float(np.mean(lengths)) -def _compute_length_median_metric(barcode): +def _compute_length_median_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the median length of intervals in the barcode. + Computes the median interval length in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The median length of intervals. + float: The median interval length. """ lengths = _get_lengths(barcode) - return np.median(lengths).item() + return float(np.median(lengths)) -def _compute_length_stdev_metric(barcode): +def _compute_length_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the standard deviation of interval lengths in the barcode. + Computes the standard deviation of interval lengths in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: float: The standard deviation of interval lengths. """ lengths = _get_lengths(barcode) - return np.std(lengths).item() + return float(np.std(lengths)) -def _compute_length_sum_metric(barcode): +def _compute_length_sum_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the sum of interval lengths in the barcode. + Computes the sum of all interval lengths in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The sum of interval lengths. + float: The sum of all interval lengths. """ lengths = _get_lengths(barcode) - return np.sum(lengths).item() + return float(np.sum(lengths)) -def _compute_two_to_one_ratio_metric(barcode): +def _compute_two_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the ratio of the second longest to the longest interval in the barcode. + Computes the ratio of the second largest to the largest interval length. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The 2-to-1 ratio of interval lengths. + float: The ratio of the second largest to the largest interval length. """ lengths = _get_lengths(barcode) value = heapq.nlargest(2, lengths)[1] / lengths[0] - return value + return float(value) -def _compute_three_to_one_ratio_metric(barcode): +def _compute_three_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the ratio of the third longest to the longest interval in the barcode. + Computes the ratio of the third largest to the largest interval length. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The 3-to-1 ratio of interval lengths. + float: The ratio of the third largest to the largest interval length. """ lengths = _get_lengths(barcode) value = heapq.nlargest(3, lengths)[2] / lengths[0] - return value + return float(value) -def _get_entropy(values, normalize): +def _get_entropy(values: np.ndarray, normalize: bool) -> float: """ - Compute the entropy of a set of values. + Computes the entropy of a given distribution. Args: - values (list): The values for which to compute entropy. + values (np.ndarray): The values to compute entropy for. normalize (bool): Whether to normalize the entropy. Returns: float: The computed entropy. """ values_sum = np.sum(values) - entropy = -np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) + entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) if normalize: - entropy /= np.log(values_sum) - return entropy + entropy = entropy / np.log(values_sum) + return float(entropy) -def _compute_entropy_metric(barcode): +def _compute_entropy_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the persistent entropy of the barcode. + Computes the persistent entropy of intervals in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: float: The persistent entropy. """ - return _get_entropy(_get_lengths(barcode), normalize=False).item() + return _get_entropy(_get_lengths(barcode), normalize=False) -def _compute_normed_entropy_metric(barcode): +def _compute_normed_entropy_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the normalized persistent entropy of the barcode. + Computes the normalized persistent entropy of intervals in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: float: The normalized persistent entropy. """ - return _get_entropy(_get_lengths(barcode), normalize=True).item() + return _get_entropy(_get_lengths(barcode), normalize=True) -def _get_births(barcode): +def _get_births(barcode: Dict[str, np.ndarray]) -> np.ndarray: """ - Extract the birth times from the barcode. + Extracts the birth times from the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - np.ndarray: The birth times. + np.ndarray: An array of birth times. """ diag = barcode["H0"] return np.array([x[0] for x in diag]) -def _get_deaths(barcode): +def _get_deaths(barcode: Dict[str, np.ndarray]) -> np.ndarray: """ - Extract the death times from the barcode. + Extracts the death times from the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - np.ndarray: The death times. + np.ndarray: An array of death times. """ diag = barcode["H0"] return np.array([x[1] for x in diag]) -def _compute_snr_metric(barcode): +def _compute_snr_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the signal-to-noise ratio (SNR) of the barcode. + Computes the signal-to-noise ratio (SNR) for the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The SNR value. + float: The computed SNR. """ births = _get_births(barcode) deaths = _get_deaths(barcode) signal = np.mean(deaths - births) noise = np.std(births) - return (signal / noise).item() + snr = signal / noise + return float(snr) -def _compute_births_mean_metric(barcode): +def _compute_births_mean_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the mean of birth times in the barcode. + Computes the mean of birth times in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The mean birth time. + float: The mean of birth times. """ - return np.mean(_get_births(barcode)).item() + return float(np.mean(_get_births(barcode))) -def _compute_births_stdev_metric(barcode): +def _compute_births_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the standard deviation of birth times in the barcode. + Computes the standard deviation of birth times in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: float: The standard deviation of birth times. """ - return np.std(_get_births(barcode)).item() + return float(np.std(_get_births(barcode))) -def _compute_deaths_mean_metric(barcode): +def _compute_deaths_mean_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the mean of death times in the barcode. + Computes the mean of death times in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: - float: The mean death time. + float: The mean of death times. """ - return np.mean(_get_deaths(barcode)).item() + return float(np.mean(_get_deaths(barcode))) -def _compute_deaths_stdev_metric(barcode): +def _compute_deaths_stdev_metric(barcode: Dict[str, np.ndarray]) -> float: """ - Compute the standard deviation of death times in the barcode. + Computes the standard deviation of death times in the barcode. Args: - barcode (dict): The barcode data. + barcode (Dict[str, np.ndarray]): The barcode to process. Returns: float: The standard deviation of death times. """ - return np.std(_get_deaths(barcode)).item() + return float(np.std(_get_deaths(barcode))) From 4ac776c7601c3ba713a761d70d4ae985ce7f9e49 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 02:37:34 +0300 Subject: [PATCH 04/16] feat: update doc strings --- eXNN/topology/api.py | 58 ++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/eXNN/topology/api.py b/eXNN/topology/api.py index 70e70cb..36d4433 100644 --- a/eXNN/topology/api.py +++ b/eXNN/topology/api.py @@ -10,21 +10,22 @@ def get_data_barcode( data: torch.Tensor, hom_type: str, - coefs_type: str, + coefficient_type: str, ) -> Dict[str, np.ndarray]: - """This function computes persistent homologies for a cloud of points as barcodes. + """ + Computes persistent homologies for a cloud of points as barcodes. Args: - data (torch.Tensor): input data of shape NxC1x...xCk, + data (torch.Tensor): Input data of shape NxC1x...xCk, where N is the number of data points, - C1,...,Ck are dimensions of each data point - hom_type (str): homotopy type - coefs_type (str): coefficients type + C1,...,Ck are dimensions of each data point. + hom_type (str): Homotopy type. + coefficient_type (str): Coefficients type. Returns: - Dict[str, np.ndarray]: barcode + Dict[str, np.ndarray]: Barcode. """ - return homologies.compute_data_barcode(data, hom_type, coefs_type) + return homologies.compute_data_barcode(data, hom_type, coefficient_type) def get_nn_barcodes( @@ -32,40 +33,39 @@ def get_nn_barcodes( data: torch.Tensor, layers: List[str], hom_type: str, - coefs_type: str, + coefficient_type: str, ) -> Dict[str, Dict[str, np.ndarray]]: """ - The function plots persistent homologies for latent representations - on different levels of the neural network as barcodes. + Computes persistent homologies for latent representations on different + levels of the neural network as barcodes. Args: - model (torch.nn.Module): neural network - data (torch.Tensor): input data of shape NxC1x...xCk, + model (torch.nn.Module): Neural network. + data (torch.Tensor): Input data of shape NxC1x...xCk, where N is the number of data points, - C1,...,Ck are dimensions of each data point - layers (List[str]): list of layers for visualization. Defaults to None. - If None, visualization for all layers is performed - hom_type (str): homotopy type - coefs_type (str): coefficients type + C1,...,Ck are dimensions of each data point. + layers (List[str]): List of layers for visualization. + hom_type (str): Homotopy type. + coefficient_type (str): Coefficients type. Returns: - Dict[str, Dict[str, np.ndarray]]: dictionary with a barcode for each layer + Dict[str, Dict[str, np.ndarray]]: Dictionary with a barcode for each layer. """ res = {} for layer in layers: - res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefs_type) + res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefficient_type) return res def plot_barcode(barcode: Dict[str, np.ndarray]) -> matplotlib.figure.Figure: """ - The function creates a plot of a persistent homologies barcode. + Creates a plot of a persistent homologies barcode. Args: - barcode (Dict[str, np.ndarray]): barcode + barcode (Dict[str, np.ndarray]): Barcode. Returns: - matplotlib.figure.Figure: a plot of the barcode + matplotlib.figure.Figure: Plot of the barcode. """ return homologies.plot_barcode(barcode) @@ -74,15 +74,15 @@ def evaluate_barcode( barcode: Dict[str, np.ndarray], metric_name: Optional[str] = None ) -> Union[float, Dict[str, float]]: """ - The function evaluates a persistent homologies barcode with a metric. + Evaluates a persistent homologies barcode with a metric. Args: - barcode (Dict[str, np.ndarray]): barcode - metric_name (Optional[str]): metric name - (if `None` all available metrics values are computed) + barcode (Dict[str, np.ndarray]): Barcode. + metric_name (Optional[str]): Metric name (if None, all available metrics + values are computed). Returns: - Union(float, Dict[str, float]): float if metric is specified - or a dictionary with value of each available metric + Union[float, Dict[str, float]]: Float if metric is specified, or a + dictionary with values of each available metric. """ return metrics.compute_metric(barcode, metric_name) From e8c2ef8028f9d785e8778d953748e9c3edf5ac80 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 02:51:34 +0300 Subject: [PATCH 05/16] feat: add docstrings --- eXNN/bayes/wrapper.py | 203 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 167 insertions(+), 36 deletions(-) diff --git a/eXNN/bayes/wrapper.py b/eXNN/bayes/wrapper.py index 8e1ed43..2681faf 100644 --- a/eXNN/bayes/wrapper.py +++ b/eXNN/bayes/wrapper.py @@ -3,18 +3,29 @@ import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional from torch.distributions import Beta class ModuleBayesianWrapper(nn.Module): + """ + A wrapper for neural network layers to apply Bayesian-style dropout or noise during training. + + Args: + layer (nn.Module): The layer to wrap (e.g., nn.Linear, nn.Conv2d). + p (Optional[float]): Dropout probability for simple dropout. Mutually exclusive with `a`, `b`, and `sigma`. + a (Optional[float]): Alpha parameter for Beta distribution dropout. Used with `b`. + b (Optional[float]): Beta parameter for Beta distribution dropout. Used with `a`. + sigma (Optional[float]): Standard deviation for Gaussian noise. Mutually exclusive with `p`, `a`, and `b`. + """ + def __init__( - self, - layer: nn.Module, - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, - sigma: Optional[float] = None, + self, + layer: nn.Module, + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, + sigma: Optional[float] = None, ): super(ModuleBayesianWrapper, self).__init__() @@ -36,6 +47,16 @@ def __init__( self.p, self.a, self.b, self.sigma = p, a, b, sigma def augment_weights(self, weights, bias): + """ + Apply the specified noise or dropout to the weights and bias. + + Args: + weights (torch.Tensor): The weights of the layer. + bias (torch.Tensor): The bias of the layer (can be None). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The augmented weights and bias. + """ # Check if dropout is chosen if (self.p is not None) or (self.a is not None and self.b is not None): @@ -45,10 +66,10 @@ def augment_weights(self, weights, bias): else: p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample() - weights = F.dropout(weights, p, training=True) + weights = functional.dropout(weights, p, training=True) if bias is not None: # In layers we sometimes have the ability to set bias to None - bias = F.dropout(bias, p, training=True) + bias = functional.dropout(bias, p, training=True) else: # If gauss is chosen, then apply it @@ -60,11 +81,20 @@ def augment_weights(self, weights, bias): return weights, bias def forward(self, x): + """ + Forward pass through the layer with augmented weights. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ weight, bias = self.augment_weights(self.layer.weight, self.layer.bias) if isinstance(self.layer, nn.Linear): - return F.linear(x, weight, bias) + return functional.linear(x, weight, bias) elif type(self.layer) in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: return self.layer._conv_forward(x, weight, bias) else: @@ -72,6 +102,18 @@ def forward(self, x): def replace_modules_with_wrapper(model, wrapper_module, params): + """ + Recursively replaces layers in a model with a Bayesian wrapper. + + Args: + model (nn.Module): The model containing layers to replace. + wrapper_module (type): The wrapper class. + params (dict): Parameters for the wrapper. + + Returns: + nn.Module: The model with wrapped layers. + """ + if type(model) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]: return wrapper_module(model, **params) @@ -88,12 +130,26 @@ def replace_modules_with_wrapper(model, wrapper_module, params): class NetworkBayes(nn.Module): + """ + Bayesian network with standard dropout. + + Args: + model (nn.Module): The base model. + dropout_p (float): Dropout probability. + """ + def __init__( - self, - model: nn.Module, - dropout_p: float, + self, + model: nn.Module, + dropout_p: float, ): + """ + Initialize the NetworkBayes with standard dropout. + Args: + model (nn.Module): The base model to wrap with Bayesian dropout. + dropout_p (float): Dropout probability for the Bayesian wrapper. + """ super(NetworkBayes, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -103,11 +159,20 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -125,13 +190,29 @@ def mean_forward( # calculate mean and std after applying bayesian with beta distribution class NetworkBayesBeta(nn.Module): + """ + Bayesian network with Beta distribution dropout. + + Args: + model (nn.Module): The base model. + alpha (float): Alpha parameter for the Beta distribution. + beta (float): Beta parameter for the Beta distribution. + """ + def __init__( - self, - model: torch.nn.Module, - alpha: float, - beta: float, + self, + model: torch.nn.Module, + alpha: float, + beta: float, ): - + """ + Initialize the NetworkBayesBeta with Beta distribution dropout. + + Args: + model (nn.Module): The base model to wrap with Bayesian Beta dropout. + alpha (float): Alpha parameter of the Beta distribution. + beta (float): Beta parameter of the Beta distribution. + """ super(NetworkBayesBeta, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -141,11 +222,20 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -163,12 +253,26 @@ def mean_forward( class NetworkBayesGauss(nn.Module): + """ + Bayesian network with Gaussian noise. + + Args: + model (nn.Module): The base model. + sigma (float): Standard deviation of the Gaussian noise. + """ + def __init__( - self, - model: torch.nn.Module, - sigma: float, + self, + model: torch.nn.Module, + sigma: float, ): + """ + Initialize the NetworkBayesGauss with Gaussian noise. + Args: + model (nn.Module): The base model to wrap with Bayesian Gaussian noise. + sigma (float): Standard deviation of the Gaussian noise to apply. + """ super(NetworkBayesGauss, self).__init__() self.model = copy.deepcopy(model) self.model = replace_modules_with_wrapper( @@ -178,11 +282,20 @@ def __init__( ) def mean_forward( - self, - data: torch.Tensor, - n_iter: int, + self, + data: torch.Tensor, + n_iter: int, ): + """ + Perform forward passes to estimate the mean and standard deviation of outputs. + Args: + data (torch.Tensor): Input tensor. + n_iter (int): Number of stochastic forward passes. + + Returns: + torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + """ results = [] for _ in range(n_iter): results.append(self.model.forward(data)) @@ -200,13 +313,28 @@ def mean_forward( def create_dropout_bayesian_wrapper( - model: torch.nn.Module, - mode: Optional[str] = "basic", - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, - sigma: Optional[float] = None, + model: torch.nn.Module, + mode: Optional[str] = "basic", + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, + sigma: Optional[float] = None, ) -> torch.nn.Module: + """ + Creates a Bayesian network with the specified dropout mode. + + Args: + model (nn.Module): The base model. + mode (str): The dropout mode ("basic", "beta", "gauss"). + p (Optional[float]): Dropout probability for "basic" mode. + a (Optional[float]): Alpha parameter for "beta" mode. + b (Optional[float]): Beta parameter for "beta" mode. + sigma (Optional[float]): Standard deviation for "gauss" mode. + + Returns: + nn.Module: The Bayesian network. + """ + if mode == "basic": net = NetworkBayes(model, p) @@ -216,4 +344,7 @@ def create_dropout_bayesian_wrapper( elif mode == 'gauss': net = NetworkBayesGauss(model, sigma) + else: + raise ValueError("Mode should be one of ('basic', 'beta', 'gauss').") + return net From c0a88f42b4174003f1dc1a1e028fa32efd30da8d Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:31:35 +0300 Subject: [PATCH 06/16] feat: change tests structure --- tests/test_bayesian/__init__.py | 0 tests/test_bayesian/test_bayes.py | 60 ++++++ tests/test_topology/__init__.py | 0 tests/test_topology/test_barcode_general.py | 57 ++++++ tests/test_topology/test_barcode_metrics.py | 50 +++++ tests/test_utils.py | 53 ------ tests/test_visualization/__init__.py | 0 .../test_check_random_input.py | 18 ++ tests/test_visualization/test_reduce_dim.py | 42 +++++ .../test_visualization/test_visualization.py | 119 ++++++++++++ tests/tests.py | 172 ------------------ tests/utils/__init__.py | 0 tests/utils/test_utils.py | 124 +++++++++++++ 13 files changed, 470 insertions(+), 225 deletions(-) create mode 100644 tests/test_bayesian/__init__.py create mode 100644 tests/test_bayesian/test_bayes.py create mode 100644 tests/test_topology/__init__.py create mode 100644 tests/test_topology/test_barcode_general.py create mode 100644 tests/test_topology/test_barcode_metrics.py delete mode 100644 tests/test_utils.py create mode 100644 tests/test_visualization/__init__.py create mode 100644 tests/test_visualization/test_check_random_input.py create mode 100644 tests/test_visualization/test_reduce_dim.py create mode 100644 tests/test_visualization/test_visualization.py delete mode 100644 tests/tests.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_utils.py diff --git a/tests/test_bayesian/__init__.py b/tests/test_bayesian/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py new file mode 100644 index 0000000..5b40f57 --- /dev/null +++ b/tests/test_bayesian/test_bayes.py @@ -0,0 +1,60 @@ +import torch + +import eXNN.bayes as bayes_api +import tests.utils.test_utils as utils + + +def _test_bayes_prediction(mode: str): + """ + Helper function to test Bayesian wrappers with different configurations. + + Args: + mode (str): Bayesian wrapper mode ("basic", "beta", "gauss"). + + Verifies: + - Output is a dictionary with keys "mean" and "std". + - Shapes of "mean" and "std" match the expected shape. + """ + params = { + "basic": dict(mode="basic", p=0.5), + "beta": dict(mode="beta", a=0.9, b=0.2), + "gauss": dict(sigma=1e-2), + } + + n, dim, data = utils.create_testing_data() + num_classes = 17 + model = utils.create_testing_model(num_classes=num_classes) + n_iter = 7 + + if mode != 'gauss': + res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + else: + res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values({"mean", "std"}, set(res.keys()), "Wrong dictionary keys") + utils.compare_values(torch.Size([n, num_classes]), res["mean"].shape, "Wrong mean shape") + utils.compare_values(torch.Size([n, num_classes]), res["std"].shape, "Wrong mean std") + + +def test_basic_bayes_wrapper(): + """ + Test the basic Bayesian wrapper. + """ + _test_bayes_prediction("basic") + + +def test_beta_bayes_wrapper(): + """ + Test the beta Bayesian wrapper. + """ + _test_bayes_prediction("beta") + + +def test_gauss_bayes_wrapper(): + """ + Test the Gaussian Bayesian wrapper. + """ + _test_bayes_prediction("gauss") + diff --git a/tests/test_topology/__init__.py b/tests/test_topology/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py new file mode 100644 index 0000000..189e194 --- /dev/null +++ b/tests/test_topology/test_barcode_general.py @@ -0,0 +1,57 @@ +import matplotlib + +import eXNN.topology as topology_api +import tests.utils.test_utils as utils + + +def test_data_barcode(): + """ + Test generating a barcode from data. + + Verifies: + - Output is a dictionary. + """ + n, dim, data = utils.create_testing_data() + res = topology_api.get_data_barcode(data, "standard", "3") + utils.compare_values(dict, type(res), "Wrong result type") + + +def test_nn_barcodes(): + """ + Test generating barcodes for a neural network. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the specified layers. + - Each dictionary value is a dictionary. + """ + n, dim, data = utils.create_testing_data() + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + + res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3") + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys") + + for layer, barcode in res.items(): + utils.compare_values( + dict, + type(barcode), + f"Wrong result type for key {layer}", + ) + + +def test_barcode_plot(): + """ + Test generating a barcode plot. + + Verifies: + - Output is a matplotlib.figure.Figure. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + plot = topology_api.plot_barcode(barcode) + utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type") + diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py new file mode 100644 index 0000000..9d8eaa2 --- /dev/null +++ b/tests/test_topology/test_barcode_metrics.py @@ -0,0 +1,50 @@ +import eXNN.topology as topology_api +import tests.utils.test_utils as utils + + +def test_barcode_evaluate_all_metrics(): + """ + Test evaluating all metrics for a barcode. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the expected metric names. + - Each metric value is a float. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode) + utils.compare_values(dict, type(result), "Wrong result type") + all_metric_names = [ + "h", + "max_length", + "mean_birth", + "mean_death", + "mean_length", + "median_length", + "normh", + "ratio_2_1", + "ratio_3_1", + "snr", + "stdev_birth", + "stdev_death", + "stdev_length", + "sum_length", + ] + utils.compare_values(all_metric_names, sorted(result.keys())) + for name, value in result.items(): + utils.compare_values(float, type(value), f"Wrong result type for metric {name}") + + +def test_barcode_evaluate_one_metric(): + """ + Test evaluating a single metric for a barcode. + + Verifies: + - Output is a float. + """ + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode, metric_name="mean_length") + utils.compare_values(float, type(result), "Wrong result type") + diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 08f08e8..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from collections import OrderedDict - -import torch -import torch.nn as nn - - -def _form_message_header(message_header=None): - return "Value mismatch" if message_header is None else message_header - - -def compare_values(expected, got, message_header=None): - assert ( - expected == got - ), f"{_form_message_header(message_header)}: expected {expected}, got {got}" - - -def create_testing_data(): - N = 20 - dim = 256 - data = torch.randn((N, dim)) - return N, dim, data - - -def create_testing_model(num_classes=10): - return nn.Sequential( - OrderedDict( - [ - ("first_layer", nn.Linear(256, 128)), - ("second_layer", nn.Linear(128, 64)), - ("third_layer", nn.Linear(64, num_classes)), - ], - ), - ) - - -class ExtractTensor(nn.Module): - def forward(self, x): - tensor, _ = x - x = x.to(torch.float32) - return tensor[:, :] - - -def create_testing_model_lstm(num_classes=10): - return nn.Sequential( - OrderedDict( - [ - ('first_layer', nn.LSTM(256, 128, 1, batch_first=True)), - ('extract', ExtractTensor()), - ('second_layer', nn.Linear(128, 64)), - ('third_layer', nn.Linear(64, num_classes)), - ], - ), - ) diff --git a/tests/test_visualization/__init__.py b/tests/test_visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_visualization/test_check_random_input.py b/tests/test_visualization/test_check_random_input.py new file mode 100644 index 0000000..5bd6360 --- /dev/null +++ b/tests/test_visualization/test_check_random_input.py @@ -0,0 +1,18 @@ +import torch + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def test_check_random_input(): + """ + Test generating random input tensors with a specific shape. + + Verifies: + - Output type is a torch.Tensor. + - Output shape matches the specified shape. + """ + shape = [5, 17, 81, 37] + data = viz_api.get_random_input(shape) + utils.compare_values(type(data), torch.Tensor, "Wrong result type") + utils.compare_values(torch.Size(shape), data.shape, "Wrong result shape") diff --git a/tests/test_visualization/test_reduce_dim.py b/tests/test_visualization/test_reduce_dim.py new file mode 100644 index 0000000..a5b1723 --- /dev/null +++ b/tests/test_visualization/test_reduce_dim.py @@ -0,0 +1,42 @@ +import numpy as np + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def _check_reduce_dim(mode): + """ + Helper function to test dimensionality reduction with a given mode. + + Args: + mode (str): Dimensionality reduction mode (e.g., "umap", "pca"). + + Verifies: + - Output type is a numpy.ndarray. + - Output shape is (n_samples, 2). + """ + n, dim, data = utils.create_testing_data() + reduced_data = viz_api.reduce_dim(data, mode) + utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type") + utils.compare_values((n, 2), reduced_data.shape, "Wrong result shape") + + +def test_reduce_dim_umap(): + """ + Test dimensionality reduction using UMAP. + + Uses: + _check_reduce_dim with mode "umap". + """ + _check_reduce_dim("umap") + + +def test_reduce_dim_pca(): + """ + Test dimensionality reduction using PCA. + + Uses: + _check_reduce_dim with mode "pca". + """ + _check_reduce_dim("pca") + diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py new file mode 100644 index 0000000..2471c0c --- /dev/null +++ b/tests/test_visualization/test_visualization.py @@ -0,0 +1,119 @@ +import matplotlib +import plotly +import torch + +import eXNN.visualization as viz_api +import tests.utils.test_utils as utils + + +def test_visualization(): + """ + Test visualization of layer manifolds using UMAP. + + Verifies: + - Output is a dictionary. + - Dictionary has the correct keys corresponding to the input and specified layers. + - Each dictionary value is a matplotlib.figure.Figure. + """ + + n, dim, data = utils.create_testing_data() + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(3, len(res), "Wrong dictionary length") + utils.compare_values( + set(["input"] + layers), + set(res.keys()), + "Wrong dictionary keys", + ) + + for key, plot in res.items(): + utils.compare_values( + matplotlib.figure.Figure, + type(plot), + f"Wrong value type for key {key}", + ) + + +def test_embed_visualization(): + """ + Test visualization of recurrent layer manifolds with embeddings. + + Verifies: + - Output is a dictionary. + - Dictionary keys match the specified layers. + - Each dictionary value is a plotly.graph_objs.Figure. + """ + data = torch.randn((20, 1, 256)) + labels = torch.randn(20) + model = utils.create_testing_model_lstm() + layers = ["second_layer", "third_layer"] + + res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels) + + utils.compare_values(dict, type(res), "Wrong result type") + utils.compare_values(2, len(res), "Wrong dictionary length") + utils.compare_values( + set(layers), + set(res.keys()), + "Wrong dictionary keys", + ) + for key, plot in res.items(): + utils.compare_values( + plotly.graph_objs.Figure, + type(plot), + f"Wrong value type for key {key}", + ) + + +def test_all_visualizations(): + """ + Test all visualizations (layer manifolds and recurrent layer manifolds). + + Verifies: + - Visualization of layer manifolds works with UMAP for regular layers. + - Visualization of recurrent layer manifolds works with embeddings. + - Correct output types for both types of plots (matplotlib and Plotly). + """ + # Test for layer manifolds visualization using UMAP + n, dim, data = utils.create_testing_data() + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers) + + utils.compare_values(dict, type(res), "Wrong result type for layer manifolds") + utils.compare_values(3, len(res), "Wrong dictionary length for layer manifolds") + utils.compare_values( + set(["input"] + layers), + set(res.keys()), + "Wrong dictionary keys for layer manifolds", + ) + + for key, plot in res.items(): + utils.compare_values( + matplotlib.figure.Figure, + type(plot), + f"Wrong value type for key {key} in layer manifolds", + ) + + # Test for recurrent layer manifolds visualization using embeddings + data = torch.randn((20, 1, 256)) + labels = torch.randn(20) + model = utils.create_testing_model_lstm() + res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels) + + utils.compare_values(dict, type(res), "Wrong result type for recurrent layer manifolds") + utils.compare_values(2, len(res), "Wrong dictionary length for recurrent layer manifolds") + utils.compare_values( + set(layers), + set(res.keys()), + "Wrong dictionary keys for recurrent layer manifolds", + ) + for key, plot in res.items(): + utils.compare_values( + plotly.graph_objs.Figure, + type(plot), + f"Wrong value type for key {key} in recurrent layer manifolds", + ) diff --git a/tests/tests.py b/tests/tests.py deleted file mode 100644 index 43edb93..0000000 --- a/tests/tests.py +++ /dev/null @@ -1,172 +0,0 @@ -import matplotlib -import numpy as np -import plotly -import torch - -import eXNN.bayes as bayes_api -import eXNN.topology as topology_api -import eXNN.visualization as viz_api -import tests.test_utils as utils - - -def test_check_random_input(): - shape = [5, 17, 81, 37] - data = viz_api.get_random_input(shape) - utils.compare_values(type(data), torch.Tensor, "Wrong result type") - utils.compare_values(torch.Size(shape), data.shape, "Wrong result shape") - - -def _check_reduce_dim(mode): - N, dim, data = utils.create_testing_data() - reduced_data = viz_api.reduce_dim(data, mode) - utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type") - utils.compare_values((N, 2), reduced_data.shape, "Wrong result shape") - - -def test_reduce_dim_umap(): - _check_reduce_dim("umap") - - -def test_reduce_dim_pca(): - _check_reduce_dim("pca") - - -def test_visualization(): - N, dim, data = utils.create_testing_data() - model = utils.create_testing_model() - layers = ["second_layer", "third_layer"] - res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers) - - utils.compare_values(dict, type(res), "Wrong result type") - utils.compare_values(3, len(res), "Wrong dictionary length") - utils.compare_values( - set(["input"] + layers), - set(res.keys()), - "Wrong dictionary keys", - ) - for key, plot in res.items(): - utils.compare_values( - matplotlib.figure.Figure, - type(plot), - f"Wrong value type for key {key}", - ) - - -def test_embed_visualization(): - data = torch.randn((20, 1, 256)) - labels = torch.randn((20)) - model = utils.create_testing_model_lstm() - layers = ["second_layer", "third_layer"] - res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", - data, layers=layers, labels=labels) - utils.compare_values(dict, type(res), "Wrong result type") - utils.compare_values(2, len(res), "Wrong dictionary length") - utils.compare_values( - set(layers), - set(res.keys()), - "Wrong dictionary keys", - ) - for key, plot in res.items(): - utils.compare_values( - plotly.graph_objs.Figure, - type(plot), - f"Wrong value type for key {key}", - ) - - -def _test_bayes_prediction(mode: str): - params = { - "basic": dict(mode="basic", p=0.5), - "beta": dict(mode="beta", a=0.9, b=0.2), - "gauss": dict(sigma=1e-2), - } - - N, dim, data = utils.create_testing_data() - num_classes = 17 - model = utils.create_testing_model(num_classes=num_classes) - n_iter = 7 - if mode != 'gauss': - res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) - else: - res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, - n_iter=n_iter) - - utils.compare_values(dict, type(res), "Wrong result type") - utils.compare_values(2, len(res), "Wrong dictionary length") - utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys") - utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape") - utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std") - - -def test_basic_bayes_wrapper(): - _test_bayes_prediction("basic") - - -def test_beta_bayes_wrapper(): - _test_bayes_prediction("beta") - - -def test_gauss_bayes_wrapper(): - _test_bayes_prediction("gauss") - - -def test_data_barcode(): - N, dim, data = utils.create_testing_data() - res = topology_api.get_data_barcode(data, "standard", "3") - utils.compare_values(dict, type(res), "Wrong result type") - - -def test_nn_barcodes(): - N, dim, data = utils.create_testing_data() - model = utils.create_testing_model() - layers = ["second_layer", "third_layer"] - res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3") - utils.compare_values(dict, type(res), "Wrong result type") - utils.compare_values(2, len(res), "Wrong dictionary length") - utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys") - for layer, barcode in res.items(): - utils.compare_values( - dict, - type(barcode), - f"Wrong result type for key {layer}", - ) - - -def test_barcode_plot(): - N, dim, data = utils.create_testing_data() - barcode = topology_api.get_data_barcode(data, "standard", "3") - plot = topology_api.plot_barcode(barcode) - utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type") - - -def test_barcode_evaluate_all_metrics(): - N, dim, data = utils.create_testing_data() - barcode = topology_api.get_data_barcode(data, "standard", "3") - result = topology_api.evaluate_barcode(barcode) - utils.compare_values(dict, type(result), "Wrong result type") - all_metric_names = [ - "h", - "max_length", - "mean_birth", - "mean_death", - "mean_length", - "median_length", - "normh", - "ratio_2_1", - "ratio_3_1", - "snr", - "stdev_birth", - "stdev_death", - "stdev_length", - "sum_length", - ] - utils.compare_values(all_metric_names, sorted(result.keys())) - for name, value in result.items(): - utils.compare_values(float, type(value), f"Wrong result type for metric {name}") - - -def test_barcode_evaluate_one_metric(): - N, dim, data = utils.create_testing_data() - barcode = topology_api.get_data_barcode(data, "standard", "3") - result = topology_api.evaluate_barcode(barcode, metric_name="mean_length") - utils.compare_values(float, type(result), "Wrong result type") diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py new file mode 100644 index 0000000..47c85f7 --- /dev/null +++ b/tests/utils/test_utils.py @@ -0,0 +1,124 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + + +def _form_message_header(message_header=None): + """ + Forms the header for assertion messages. + + Args: + message_header (str, optional): Custom message header. Defaults to None. + + Returns: + str: The custom header if provided, otherwise "Value mismatch". + """ + return "Value mismatch" if message_header is None else message_header + + +def compare_values(expected, got, message_header=None): + """ + Compares two values and raises an assertion error if they do not match. + + Args: + expected: The expected value. + got: The value to compare against the expected value. + message_header (str, optional): Custom header for the assertion message. Defaults to None. + + Raises: + AssertionError: If the values do not match, an error with the message header is raised. + """ + assert ( + expected == got + ), f"{_form_message_header(message_header)}: expected {expected}, got {got}" + + +def create_testing_data(): + """ + Creates synthetic testing data. + + Returns: + tuple: A tuple containing: + - N (int): Number of samples. + - dim (int): Dimensionality of each sample. + - data (torch.Tensor): Randomly generated data tensor of shape (N, dim). + """ + n = 20 + dim = 256 + data = torch.randn((n, dim)) + return n, dim, data + + +def create_testing_model(num_classes=10): + """ + Creates a simple feedforward neural network for testing. + + Args: + num_classes (int, optional): Number of output classes. Defaults to 10. + + Returns: + nn.Sequential: A sequential model with three layers: + - Linear layer (input_dim=256, output_dim=128). + - Linear layer (input_dim=128, output_dim=64). + - Linear layer (input_dim=64, output_dim=num_classes). + """ + return nn.Sequential( + OrderedDict( + [ + ("first_layer", nn.Linear(256, 128)), + ("second_layer", nn.Linear(128, 64)), + ("third_layer", nn.Linear(64, num_classes)), + ], + ), + ) + + +class ExtractTensor(nn.Module): + """ + A custom PyTorch module to extract and process tensors. + + This module extracts the first tensor from a tuple, converts it to + float32, and returns its values. + """ + + @staticmethod + def forward(x): + """ + Forward pass for tensor extraction and processing. + + Args: + x (tuple): Input tuple where the first element is the tensor to be processed. + + Returns: + torch.Tensor: Processed tensor (converted to float32). + """ + tensor, _ = x + tensor = tensor.to(torch.float32) + return tensor[:, :] + + +def create_testing_model_lstm(num_classes=10): + """ + Creates a recurrent neural network with LSTM layers for testing. + + Args: + num_classes (int, optional): Number of output classes. Defaults to 10. + + Returns: + nn.Sequential: A sequential model with the following layers: + - LSTM layer (input_dim=256, hidden_dim=128, num_layers=1). + - ExtractTensor layer to process LSTM output. + - Linear layer (input_dim=128, output_dim=64). + - Linear layer (input_dim=64, output_dim=num_classes). + """ + return nn.Sequential( + OrderedDict( + [ + ('first_layer', nn.LSTM(256, 128, 1, batch_first=True)), + ('extract', ExtractTensor()), + ('second_layer', nn.Linear(128, 64)), + ('third_layer', nn.Linear(64, num_classes)), + ], + ), + ) From 0f5ba794045554cfd2f578426acdfb8a37df05d6 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:34:23 +0300 Subject: [PATCH 07/16] feat: add new tests and remove extra one for visualization --- tests/test_bayesian/test_bayes.py | 9 ++++ tests/test_topology/test_barcode_general.py | 26 ++++++++++ tests/test_topology/test_barcode_metrics.py | 37 ++++++++++++++ tests/test_visualization/test_reduce_dim.py | 13 +++++ .../test_visualization/test_visualization.py | 51 ------------------- 5 files changed, 85 insertions(+), 51 deletions(-) diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py index 5b40f57..290fb7c 100644 --- a/tests/test_bayesian/test_bayes.py +++ b/tests/test_bayesian/test_bayes.py @@ -58,3 +58,12 @@ def test_gauss_bayes_wrapper(): """ _test_bayes_prediction("gauss") + +def test_all_bayes_wrappers(): + """ + Test all Bayesian wrappers (basic, beta, and gauss) with a single general test. + """ + modes = ["basic", "beta", "gauss"] + + for mode in modes: + _test_bayes_prediction(mode) diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py index 189e194..174ce28 100644 --- a/tests/test_topology/test_barcode_general.py +++ b/tests/test_topology/test_barcode_general.py @@ -55,3 +55,29 @@ def test_barcode_plot(): plot = topology_api.plot_barcode(barcode) utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type") + +def test_all_barcodes(): + """ + Test all barcode-related functions (data barcode, NN barcodes, and barcode plot) together. + + Verifies: + - Data barcode is generated correctly. + - Neural network barcodes are generated correctly for specified layers. + - Barcode plot is generated correctly. + """ + # Test data barcode + n, dim, data = utils.create_testing_data() + res = topology_api.get_data_barcode(data, "standard", "3") + utils.compare_values(dict, type(res), "Wrong result type for data barcode") + + # Test NN barcodes + model = utils.create_testing_model() + layers = ["second_layer", "third_layer"] + nn_barcodes = topology_api.get_nn_barcodes(model, data, layers, "standard", "3") + utils.compare_values(dict, type(nn_barcodes), "Wrong result type for NN barcodes") + utils.compare_values(2, len(nn_barcodes), "Wrong dictionary length for NN barcodes") + utils.compare_values(set(layers), set(nn_barcodes.keys()), "Wrong dictionary keys for NN barcodes") + + # Test barcode plot + barcode_plot = topology_api.plot_barcode(res) + utils.compare_values(matplotlib.figure.Figure, type(barcode_plot), "Wrong result type for barcode plot") diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py index 9d8eaa2..53a3a29 100644 --- a/tests/test_topology/test_barcode_metrics.py +++ b/tests/test_topology/test_barcode_metrics.py @@ -48,3 +48,40 @@ def test_barcode_evaluate_one_metric(): result = topology_api.evaluate_barcode(barcode, metric_name="mean_length") utils.compare_values(float, type(result), "Wrong result type") + +def test_barcode_evaluate_all_metrics_and_individual(): + """ + Test evaluating all metrics for a barcode and individual metrics evaluation. + + Verifies: + - Output for evaluating all metrics is a dictionary. + - Dictionary keys match the expected metric names. + - Each metric value is a float. + - Individual metrics can be correctly evaluated. + """ + # Test for all metrics + n, dim, data = utils.create_testing_data() + barcode = topology_api.get_data_barcode(data, "standard", "3") + result = topology_api.evaluate_barcode(barcode) + + # Check that the result is a dictionary + utils.compare_values(dict, type(result), "Wrong result type") + + # List of expected metric names + all_metric_names = [ + "h", "max_length", "mean_birth", "mean_death", "mean_length", + "median_length", "normh", "ratio_2_1", "ratio_3_1", "snr", + "stdev_birth", "stdev_death", "stdev_length", "sum_length" + ] + + # Check that the dictionary keys match the expected metric names + utils.compare_values(all_metric_names, sorted(result.keys()), "Wrong dictionary keys") + + # Ensure all metric values are floats + for name, value in result.items(): + utils.compare_values(float, type(value), f"Wrong result type for metric {name}") + + # Test for evaluating individual metrics + for metric_name in all_metric_names: + individual_result = topology_api.evaluate_barcode(barcode, metric_name=metric_name) + utils.compare_values(float, type(individual_result), f"Wrong result type for individual metric {metric_name}") diff --git a/tests/test_visualization/test_reduce_dim.py b/tests/test_visualization/test_reduce_dim.py index a5b1723..8ed32b9 100644 --- a/tests/test_visualization/test_reduce_dim.py +++ b/tests/test_visualization/test_reduce_dim.py @@ -40,3 +40,16 @@ def test_reduce_dim_pca(): """ _check_reduce_dim("pca") + +def test_all_reduce_dim_methods(): + """ + Test dimensionality reduction using all available methods (e.g., UMAP, PCA). + + Verifies: + - Dimensionality reduction works for each method. + - Output is of type numpy.ndarray and shape (n_samples, 2). + """ + modes = ["umap", "pca"] + + for mode in modes: + _check_reduce_dim(mode) diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py index 2471c0c..31a6206 100644 --- a/tests/test_visualization/test_visualization.py +++ b/tests/test_visualization/test_visualization.py @@ -66,54 +66,3 @@ def test_embed_visualization(): type(plot), f"Wrong value type for key {key}", ) - - -def test_all_visualizations(): - """ - Test all visualizations (layer manifolds and recurrent layer manifolds). - - Verifies: - - Visualization of layer manifolds works with UMAP for regular layers. - - Visualization of recurrent layer manifolds works with embeddings. - - Correct output types for both types of plots (matplotlib and Plotly). - """ - # Test for layer manifolds visualization using UMAP - n, dim, data = utils.create_testing_data() - model = utils.create_testing_model() - layers = ["second_layer", "third_layer"] - res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers) - - utils.compare_values(dict, type(res), "Wrong result type for layer manifolds") - utils.compare_values(3, len(res), "Wrong dictionary length for layer manifolds") - utils.compare_values( - set(["input"] + layers), - set(res.keys()), - "Wrong dictionary keys for layer manifolds", - ) - - for key, plot in res.items(): - utils.compare_values( - matplotlib.figure.Figure, - type(plot), - f"Wrong value type for key {key} in layer manifolds", - ) - - # Test for recurrent layer manifolds visualization using embeddings - data = torch.randn((20, 1, 256)) - labels = torch.randn(20) - model = utils.create_testing_model_lstm() - res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels) - - utils.compare_values(dict, type(res), "Wrong result type for recurrent layer manifolds") - utils.compare_values(2, len(res), "Wrong dictionary length for recurrent layer manifolds") - utils.compare_values( - set(layers), - set(res.keys()), - "Wrong dictionary keys for recurrent layer manifolds", - ) - for key, plot in res.items(): - utils.compare_values( - plotly.graph_objs.Figure, - type(plot), - f"Wrong value type for key {key} in recurrent layer manifolds", - ) From 842803ba82cec2abbc23733a8dc53318c36f9cfb Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:41:30 +0300 Subject: [PATCH 08/16] feat: change some filenames --- .idea/.gitignore | 3 +++ .idea/eXplain-NNs.iml | 15 +++++++++++++++ .idea/inspectionProfiles/Project_Default.xml | 6 ++++++ .../inspectionProfiles/profiles_settings.xml | 6 ++++++ .idea/misc.xml | 7 +++++++ .idea/modules.xml | 8 ++++++++ .idea/vcs.xml | 6 ++++++ README_eng.md | 2 +- README.md => README_ru.md | 0 ...ors.svg => strong_ai_in_industry_logo.svg} | 0 eXNN/bayes/api.py | 19 +++++++++---------- 11 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/eXplain-NNs.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml rename README.md => README_ru.md (100%) rename docs/{AIM-Strong_Sign_Norm-01_Colors.svg => strong_ai_in_industry_logo.svg} (100%) diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/eXplain-NNs.iml b/.idea/eXplain-NNs.iml new file mode 100644 index 0000000..28bc459 --- /dev/null +++ b/.idea/eXplain-NNs.iml @@ -0,0 +1,15 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..b95d966 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..c4c2635 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..ab7e721 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README_eng.md b/README_eng.md index 6203aec..035c3a3 100644 --- a/README_eng.md +++ b/README_eng.md @@ -7,7 +7,7 @@ [![Documentation](https://github.com/aimclub/eXplain-NNs/actions/workflows/pages/pages-build-deployment/badge.svg)](https://med-ai-lab.github.io/eXplain-NNs-documentation/) [![license](https://img.shields.io/github/license/aimclub/eXplain-NNs)](https://github.com/aimclub/eXplain-NNs/blob/main/LICENSE) -[![Rus](https://img.shields.io/badge/lang-ru-yellow.svg)](/README.md) +[![Rus](https://img.shields.io/badge/lang-ru-yellow.svg)](/README_ru) [![Mirror](https://img.shields.io/badge/mirror-GitLab-orange)](https://gitlab.actcognitive.org/itmo-sai-code/eXplain-NNs) # eXplain-NNs diff --git a/README.md b/README_ru.md similarity index 100% rename from README.md rename to README_ru.md diff --git a/docs/AIM-Strong_Sign_Norm-01_Colors.svg b/docs/strong_ai_in_industry_logo.svg similarity index 100% rename from docs/AIM-Strong_Sign_Norm-01_Colors.svg rename to docs/strong_ai_in_industry_logo.svg diff --git a/eXNN/bayes/api.py b/eXNN/bayes/api.py index d2910bb..fb2a053 100755 --- a/eXNN/bayes/api.py +++ b/eXNN/bayes/api.py @@ -1,6 +1,5 @@ from typing import Dict, Optional -import torch import torch.optim from eXNN.bayes.wrapper import create_dropout_bayesian_wrapper @@ -8,12 +7,12 @@ class DropoutBayesianWrapper: def __init__( - self, - model: torch.nn.Module, - mode: str, - p: Optional[float] = None, - a: Optional[float] = None, - b: Optional[float] = None, + self, + model: torch.nn.Module, + mode: str, + p: Optional[float] = None, + a: Optional[float] = None, + b: Optional[float] = None, ): """Class representing bayesian equivalent of a neural network. @@ -49,9 +48,9 @@ def predict(self, data, n_iter) -> Dict[str, torch.Tensor]: class GaussianBayesianWrapper: def __init__( - self, - model: torch.nn.Module, - sigma: float, + self, + model: torch.nn.Module, + sigma: float, ): """Class representing bayesian equivalent of a neural network. From 0b5a4cc7a4016be5b92ca6e0e75d3fec5ab38a41 Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:42:59 +0300 Subject: [PATCH 09/16] fix: change path to README_ru --- README_eng.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README_eng.md b/README_eng.md index 035c3a3..ca7c4e1 100644 --- a/README_eng.md +++ b/README_eng.md @@ -7,7 +7,7 @@ [![Documentation](https://github.com/aimclub/eXplain-NNs/actions/workflows/pages/pages-build-deployment/badge.svg)](https://med-ai-lab.github.io/eXplain-NNs-documentation/) [![license](https://img.shields.io/github/license/aimclub/eXplain-NNs)](https://github.com/aimclub/eXplain-NNs/blob/main/LICENSE) -[![Rus](https://img.shields.io/badge/lang-ru-yellow.svg)](/README_ru) +[![Rus](https://img.shields.io/badge/lang-ru-yellow.svg)](/README_ru.md) [![Mirror](https://img.shields.io/badge/mirror-GitLab-orange)](https://gitlab.actcognitive.org/itmo-sai-code/eXplain-NNs) # eXplain-NNs From 539c56ce2251abefea814fe77e2bb875e3dc20cf Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:46:27 +0300 Subject: [PATCH 10/16] feat: add .idea folder to ignored --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index b6e4761..a477fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea +/.idea +.idea/ \ No newline at end of file From cee2360ca42f28aa3bbaabe3eeed3547dfdb4b95 Mon Sep 17 00:00:00 2001 From: Asmoorr <93661487+Asmoorr@users.noreply.github.com> Date: Sun, 22 Dec 2024 03:46:57 +0300 Subject: [PATCH 11/16] fix: delete .idea directory --- .idea/.gitignore | 3 --- .idea/eXplain-NNs.iml | 15 --------------- .idea/inspectionProfiles/Project_Default.xml | 6 ------ .idea/inspectionProfiles/profiles_settings.xml | 6 ------ .idea/misc.xml | 7 ------- .idea/modules.xml | 8 -------- .idea/vcs.xml | 6 ------ 7 files changed, 51 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/eXplain-NNs.iml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 26d3352..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml diff --git a/.idea/eXplain-NNs.iml b/.idea/eXplain-NNs.iml deleted file mode 100644 index 28bc459..0000000 --- a/.idea/eXplain-NNs.iml +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index b95d966..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index c4c2635..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index ab7e721..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file From ffbd00099f6e23ffa5b17eac113bcc957700e8fd Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:54:59 +0300 Subject: [PATCH 12/16] fix: rename back to just README instead of README_eng --- README_eng.md => README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename README_eng.md => README.md (100%) diff --git a/README_eng.md b/README.md similarity index 100% rename from README_eng.md rename to README.md From 33b7cee837a69c744b5c54704f24e9c1bd8ad48e Mon Sep 17 00:00:00 2001 From: asmoorr <90kidex90@gmail.com> Date: Sun, 22 Dec 2024 03:56:51 +0300 Subject: [PATCH 13/16] fix: translate some test to russian --- README_ru.md | 101 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/README_ru.md b/README_ru.md index 8352786..e6417ec 100644 --- a/README_ru.md +++ b/README_ru.md @@ -7,132 +7,162 @@ [![Documentation](https://github.com/aimclub/eXplain-NNs/actions/workflows/pages/pages-build-deployment/badge.svg)](https://med-ai-lab.github.io/eXplain-NNs-documentation/) [![license](https://img.shields.io/github/license/aimclub/eXplain-NNs)](https://github.com/aimclub/eXplain-NNs/blob/main/LICENSE) -[![Eng](https://img.shields.io/badge/lang-en-red.svg)](/README_eng.md) +[![Eng](https://img.shields.io/badge/lang-en-red.svg)](/README) [![Mirror](https://img.shields.io/badge/mirror-GitLab-orange)](https://gitlab.actcognitive.org/itmo-sai-code/eXplain-NNs) # eXplain-NNs -Этот репозиторий содержит библиотеку eXplain-NNs — библиотеку с открытым исходным кодом с методами объяснимого ИИ (XAI) для анализа нейронных сетей. Эта библиотека предоставляет несколько методов XAI для анализа латентных пространств и оценки неопределенности. + +Этот репозиторий содержит библиотеку eXplain-NNs — библиотеку с открытым исходным кодом с методами объяснимого ИИ (XAI) +для анализа нейронных сетей. Эта библиотека предоставляет несколько методов XAI для анализа латентных пространств и +оценки неопределенности. ## Описание проекта ### Методы + Методы XAI, реализованные в библиотеке + 1. визуализация латентных пространств -1. гомологический анализ латентных пространств -1. оценка неопределенности с помощью байесианизации +2. гомологический анализ латентных пространств +3. оценка неопределенности с помощью байесианизации Таким образом, по сравнению с другими библиотеками объяснимого ИИ библиотека eXplain-NNs: + * Обеспечивает анализ гомологий латентных пространств -* Внедряет новый метод оценки неопределенности с помощью байесианизации XAI для анализа латентных пространств и оценки неопределенности. +* Внедряет новый метод оценки неопределенности с помощью байесианизации XAI для анализа латентных пространств и оценки + неопределенности. Детали [реализации методов](/docs/methods.md). -### Data Requirement -* The library supports only models that are: - * fully connected or convolutional - * designed for classification task +### Требования к данным + +* Библиотека поддерживает только модели, которые являются: + * полносвязные или конволюционные + * разработаны для задачи классификации + +## Установка -## Installation Требования: Python 3.8 + 1. [optional] создайте среду окружения Python, e.g. ``` $ conda create -n eXNN python=3.8 $ conda activate eXNN ``` -1. установите зависимости из [requirements.txt](/requirements.txt) +2. установите зависимости из [requirements.txt](/requirements.txt) ``` $ pip install -r requirements.txt ``` -1. установите библиотеку как пакет +3. установите библиотеку как пакет ``` $ python -m pip install git+ssh://git@github.com/Med-AI-Lab/eXplain-NNs ``` -Видео с процессом установки можно посмотреть [здесь](https://drive.google.com/file/d/1Sv8UiRwWfMLJ0kOSYHB_PgILHzNcqfs0/view?usp=sharing). - +Видео с процессом установки можно +посмотреть [здесь](https://drive.google.com/file/d/1Sv8UiRwWfMLJ0kOSYHB_PgILHzNcqfs0/view?usp=sharing). ## Development + Требования: Python 3.8 + 1. [optional] создайте среды окружения Python, e.g. ``` $ conda create -n eXNN python=3.8 $ conda activate eXNN ``` -1. клонируйте репозиторий и установите зависимости +2. клонируйте репозиторий и установите зависимости ``` $ git clone git@github.com:Med-AI-Lab/eXplain-NNs.git $ cd eXplain-NNs $ pip install -r requirements.txt ``` -1. запуск тестов +3. запуск тестов ``` - $ pytest tests/tests.py + $ pytest tests ``` -1. приведение стиля кода в соотвествие с PEP8 автоматически +4. приведение стиля кода в соответствие с PEP8 автоматически ``` $ make format ``` -1. проверка стиля кода на соотвествие с PEP8 +5. проверка стиля кода на соответствие с PEP8 ``` $ make check ``` -1. создание PyPi пакета локально +6. создание PyPi пакета локально ``` $ python3 -m pip install --upgrade build $ python3 -m build ``` ## Документация + [Документация](https://med-ai-lab.github.io/eXplain-NNs-documentation/) [API](https://med-ai-lab.github.io/eXplain-NNs-documentation/api_docs/eXNN.html) +## Примеры и туториалы -## Примеры и тьюториалы Мы предоставляем примеры разного уровня сложности: + * [минимальные] минималистичные примеры, представляющие наш API * [базовые] применение eXNN для простых задач, таких как классификация MNIST -* [сценарии использования] демонстрация использования eXplain-NN для решения различных проблем, возникающих в промышленных задачах. +* [сценарии использования] демонстрация использования eXplain-NN для решения различных проблем, возникающих в + промышленных задачах. ### Минимальные + Этот колаб содержит минималистическую демонстрацию нашего API на фиктивных объектах: [![minimal](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1lOiB50LppDiiRHTv184JMuQ2IvZ4I4rp?usp=sharing) ### Базовые + Вот колабы, демонстрирующие, как работать с разными модулями нашего API на простых задачах: -| Colab Link | Module | -| ------------- | ------------- | -| [![bayes](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ayd0IronxUIfnbAmWQLHiILG2qtBBpF4?usp=sharing)| bayes | -| [![topology](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T5ENfNaCIRI61LM2ZhtU8lfmvRmlfiEo?usp=sharing)| topology | -| [![visualization](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LJVdWTv-wcASSMX4is_E15TR7XJsT7W3?usp=sharing)| visualization | +| Colab Link | Module | +|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| +| [![bayes](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ayd0IronxUIfnbAmWQLHiILG2qtBBpF4?usp=sharing) | bayes | +| [![topology](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T5ENfNaCIRI61LM2ZhtU8lfmvRmlfiEo?usp=sharing) | topology | +| [![visualization](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LJVdWTv-wcASSMX4is_E15TR7XJsT7W3?usp=sharing) | visualization | ### Сценарии использования -В этом блоке представлены примеры использования eXplain-NN для решения различных вариантов использования в промышленных задачах. Для демонстрационных целей используются 3 задачи: + +В этом блоке представлены примеры использования eXplain-NN для решения различных вариантов использования в промышленных +задачах. Для демонстрационных целей используются 3 задачи: + * [спутник] классификация ландшафтов по спутниковым снимкам. * [электроника] классификация электронных компонентов и устройств * [ЭКГ] диагностика ЭКГ -| Colab Link | Task | Use Case | -| ------------- | ------------- | ------------- | -| [![CNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12ZJigH-0geGTefNXnCM5dQ71d4tqlf6L?usp=sharing)| спутник | Визуализация изменения многообразия данных от слоя к слою | -| [![adv](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1n50WUu2ZKZ6nrT9DuFD3q87m3yZvxkwm?usp=sharing) | спутник | Детекция adversarial данных | -| [![generalize](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mG-VrP7J7OoCvIQDl7n5YWEIdyfFg_0I?usp=sharing) | электроника | Оценка обобщающей способности нейронной сети | -| [![RNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aAtqxQLcOsSJJumfsmS9HGLgHrOFHlfk?usp=sharing) | ЭКГ | Визуализация изменения многообразия данных от слоя к слою | +| Colab Link | Task | Use Case | +|------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------|-----------------------------------------------------------| +| [![CNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12ZJigH-0geGTefNXnCM5dQ71d4tqlf6L?usp=sharing) | спутник | Визуализация изменения многообразия данных от слоя к слою | +| [![adv](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1n50WUu2ZKZ6nrT9DuFD3q87m3yZvxkwm?usp=sharing) | спутник | Детекция adversarial данных | +| [![generalize](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mG-VrP7J7OoCvIQDl7n5YWEIdyfFg_0I?usp=sharing) | электроника | Оценка обобщающей способности нейронной сети | +| [![RNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aAtqxQLcOsSJJumfsmS9HGLgHrOFHlfk?usp=sharing) | ЭКГ | Визуализация изменения многообразия данных от слоя к слою | ## Как помочь проекту + [Инструкции](/docs/contribution.md). ## Организационное ### Аффилиация + [Университет ИТМО](https://en.itmo.ru/). ### Поддержка -Исследование проводится при поддержке [Исследовательского центра сильного искусственного интеллекта в промышленности]() [Университета ИТМО](https://itmo.ru) в рамках мероприятия программы центра: Разработка и испытания экспериментального образца библиотеки алгоритмов сильного ИИ в части алгоритмов объяснения результатов моделирования на данных с использованием семантики и терминологии предметной и проблемной областей в задачах с высокой неопределенностью, включая оценку неопределенности предсказаний моделей нейронных сетей, а также анализ и визуализацию межслойных трансформаций входного многообразия внутри нейронных сетей. + +Исследование проводится при +поддержке [Исследовательского центра сильного искусственного интеллекта в промышленности]() [Университета ИТМО](https://itmo.ru) +в рамках мероприятия программы центра: Разработка и испытания экспериментального образца библиотеки алгоритмов сильного +ИИ в части алгоритмов объяснения результатов моделирования на данных с использованием семантики и терминологии +предметной и проблемной областей в задачах с высокой неопределенностью, включая оценку неопределенности предсказаний +моделей нейронных сетей, а также анализ и визуализацию межслойных трансформаций входного многообразия внутри нейронных +сетей. ### Разработчики + * А. Ватьян - тим лид * Н. Гусарова - научный руководитель * И. Томилов @@ -140,5 +170,6 @@ * К. Никулина ## Контакты + * Александра Ватьян alexvatyan@gmail.com по вопросам сотрудничества * Татьяна Полевая tpolevaya@itmo.ru по техническим вопросам From dfd99c033cc9ad01c9e50626af6b4d59c4cb45cc Mon Sep 17 00:00:00 2001 From: Asmoorr <90kidex90@gmail.com> Date: Mon, 13 Jan 2025 20:45:14 +0300 Subject: [PATCH 14/16] fix: remove extra file with translation --- README_ru.md | 175 --------------------------------------------------- 1 file changed, 175 deletions(-) delete mode 100644 README_ru.md diff --git a/README_ru.md b/README_ru.md deleted file mode 100644 index e6417ec..0000000 --- a/README_ru.md +++ /dev/null @@ -1,175 +0,0 @@ -

- -

- -[![SAI](https://github.com/ITMO-NSS-team/open-source-ops/blob/master/badges/SAI_badge_flat.svg)](https://sai.itmo.ru/) -[![ITMO](https://github.com/ITMO-NSS-team/open-source-ops/blob/master/badges/ITMO_badge_flat_rus.svg)](https://en.itmo.ru/en/) - -[![Documentation](https://github.com/aimclub/eXplain-NNs/actions/workflows/pages/pages-build-deployment/badge.svg)](https://med-ai-lab.github.io/eXplain-NNs-documentation/) -[![license](https://img.shields.io/github/license/aimclub/eXplain-NNs)](https://github.com/aimclub/eXplain-NNs/blob/main/LICENSE) -[![Eng](https://img.shields.io/badge/lang-en-red.svg)](/README) -[![Mirror](https://img.shields.io/badge/mirror-GitLab-orange)](https://gitlab.actcognitive.org/itmo-sai-code/eXplain-NNs) - -# eXplain-NNs - -Этот репозиторий содержит библиотеку eXplain-NNs — библиотеку с открытым исходным кодом с методами объяснимого ИИ (XAI) -для анализа нейронных сетей. Эта библиотека предоставляет несколько методов XAI для анализа латентных пространств и -оценки неопределенности. - -## Описание проекта - -### Методы - -Методы XAI, реализованные в библиотеке - -1. визуализация латентных пространств -2. гомологический анализ латентных пространств -3. оценка неопределенности с помощью байесианизации - -Таким образом, по сравнению с другими библиотеками объяснимого ИИ библиотека eXplain-NNs: - -* Обеспечивает анализ гомологий латентных пространств -* Внедряет новый метод оценки неопределенности с помощью байесианизации XAI для анализа латентных пространств и оценки - неопределенности. - -Детали [реализации методов](/docs/methods.md). - -### Требования к данным - -* Библиотека поддерживает только модели, которые являются: - * полносвязные или конволюционные - * разработаны для задачи классификации - -## Установка - -Требования: Python 3.8 - -1. [optional] создайте среду окружения Python, e.g. - ``` - $ conda create -n eXNN python=3.8 - $ conda activate eXNN - ``` -2. установите зависимости из [requirements.txt](/requirements.txt) - ``` - $ pip install -r requirements.txt - ``` -3. установите библиотеку как пакет - ``` - $ python -m pip install git+ssh://git@github.com/Med-AI-Lab/eXplain-NNs - ``` - -Видео с процессом установки можно -посмотреть [здесь](https://drive.google.com/file/d/1Sv8UiRwWfMLJ0kOSYHB_PgILHzNcqfs0/view?usp=sharing). - -## Development - -Требования: Python 3.8 - -1. [optional] создайте среды окружения Python, e.g. - ``` - $ conda create -n eXNN python=3.8 - $ conda activate eXNN - ``` -2. клонируйте репозиторий и установите зависимости - ``` - $ git clone git@github.com:Med-AI-Lab/eXplain-NNs.git - $ cd eXplain-NNs - $ pip install -r requirements.txt - ``` -3. запуск тестов - ``` - $ pytest tests - ``` -4. приведение стиля кода в соответствие с PEP8 автоматически - ``` - $ make format - ``` -5. проверка стиля кода на соответствие с PEP8 - ``` - $ make check - ``` -6. создание PyPi пакета локально - ``` - $ python3 -m pip install --upgrade build - $ python3 -m build - ``` - -## Документация - -[Документация](https://med-ai-lab.github.io/eXplain-NNs-documentation/) - -[API](https://med-ai-lab.github.io/eXplain-NNs-documentation/api_docs/eXNN.html) - -## Примеры и туториалы - -Мы предоставляем примеры разного уровня сложности: - -* [минимальные] минималистичные примеры, представляющие наш API -* [базовые] применение eXNN для простых задач, таких как классификация MNIST -* [сценарии использования] демонстрация использования eXplain-NN для решения различных проблем, возникающих в - промышленных задачах. - -### Минимальные - -Этот колаб содержит минималистическую демонстрацию нашего API на фиктивных объектах: - -[![minimal](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1lOiB50LppDiiRHTv184JMuQ2IvZ4I4rp?usp=sharing) - -### Базовые - -Вот колабы, демонстрирующие, как работать с разными модулями нашего API на простых задачах: - -| Colab Link | Module | -|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| -| [![bayes](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ayd0IronxUIfnbAmWQLHiILG2qtBBpF4?usp=sharing) | bayes | -| [![topology](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T5ENfNaCIRI61LM2ZhtU8lfmvRmlfiEo?usp=sharing) | topology | -| [![visualization](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LJVdWTv-wcASSMX4is_E15TR7XJsT7W3?usp=sharing) | visualization | - -### Сценарии использования - -В этом блоке представлены примеры использования eXplain-NN для решения различных вариантов использования в промышленных -задачах. Для демонстрационных целей используются 3 задачи: - -* [спутник] классификация ландшафтов по спутниковым снимкам. -* [электроника] классификация электронных компонентов и устройств -* [ЭКГ] диагностика ЭКГ - -| Colab Link | Task | Use Case | -|------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------|-----------------------------------------------------------| -| [![CNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12ZJigH-0geGTefNXnCM5dQ71d4tqlf6L?usp=sharing) | спутник | Визуализация изменения многообразия данных от слоя к слою | -| [![adv](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1n50WUu2ZKZ6nrT9DuFD3q87m3yZvxkwm?usp=sharing) | спутник | Детекция adversarial данных | -| [![generalize](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mG-VrP7J7OoCvIQDl7n5YWEIdyfFg_0I?usp=sharing) | электроника | Оценка обобщающей способности нейронной сети | -| [![RNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aAtqxQLcOsSJJumfsmS9HGLgHrOFHlfk?usp=sharing) | ЭКГ | Визуализация изменения многообразия данных от слоя к слою | - -## Как помочь проекту - -[Инструкции](/docs/contribution.md). - -## Организационное - -### Аффилиация - -[Университет ИТМО](https://en.itmo.ru/). - -### Поддержка - -Исследование проводится при -поддержке [Исследовательского центра сильного искусственного интеллекта в промышленности]() [Университета ИТМО](https://itmo.ru) -в рамках мероприятия программы центра: Разработка и испытания экспериментального образца библиотеки алгоритмов сильного -ИИ в части алгоритмов объяснения результатов моделирования на данных с использованием семантики и терминологии -предметной и проблемной областей в задачах с высокой неопределенностью, включая оценку неопределенности предсказаний -моделей нейронных сетей, а также анализ и визуализацию межслойных трансформаций входного многообразия внутри нейронных -сетей. - -### Разработчики - -* А. Ватьян - тим лид -* Н. Гусарова - научный руководитель -* И. Томилов -* Т. Полевая -* К. Никулина - -## Контакты - -* Александра Ватьян alexvatyan@gmail.com по вопросам сотрудничества -* Татьяна Полевая tpolevaya@itmo.ru по техническим вопросам From 4e7eaaca6956e3128efa235980f4a5335c03c198 Mon Sep 17 00:00:00 2001 From: Asmoorr <90kidex90@gmail.com> Date: Mon, 13 Jan 2025 23:12:56 +0300 Subject: [PATCH 15/16] fix: change line length down to 100 or less --- eXNN/bayes/wrapper.py | 15 ++++++++++----- eXNN/topology/homologies.py | 6 ++++-- eXNN/topology/metrics.py | 3 ++- tests/test_bayesian/test_bayes.py | 9 ++++++--- tests/test_visualization/test_visualization.py | 3 ++- tests/utils/test_utils.py | 5 ++--- 6 files changed, 26 insertions(+), 15 deletions(-) diff --git a/eXNN/bayes/wrapper.py b/eXNN/bayes/wrapper.py index 2681faf..456b0c6 100644 --- a/eXNN/bayes/wrapper.py +++ b/eXNN/bayes/wrapper.py @@ -13,10 +13,12 @@ class ModuleBayesianWrapper(nn.Module): Args: layer (nn.Module): The layer to wrap (e.g., nn.Linear, nn.Conv2d). - p (Optional[float]): Dropout probability for simple dropout. Mutually exclusive with `a`, `b`, and `sigma`. + p (Optional[float]): Dropout probability for simple dropout. + Mutually exclusive with `a`, `b`, and `sigma`. a (Optional[float]): Alpha parameter for Beta distribution dropout. Used with `b`. b (Optional[float]): Beta parameter for Beta distribution dropout. Used with `a`. - sigma (Optional[float]): Standard deviation for Gaussian noise. Mutually exclusive with `p`, `a`, and `b`. + sigma (Optional[float]): Standard deviation for Gaussian noise. + Mutually exclusive with `p`, `a`, and `b`. """ def __init__( @@ -171,7 +173,8 @@ def mean_forward( n_iter (int): Number of stochastic forward passes. Returns: - torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. """ results = [] for _ in range(n_iter): @@ -234,7 +237,8 @@ def mean_forward( n_iter (int): Number of stochastic forward passes. Returns: - torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. """ results = [] for _ in range(n_iter): @@ -294,7 +298,8 @@ def mean_forward( n_iter (int): Number of stochastic forward passes. Returns: - torch.Tensor: A tensor containing the mean (dim=0) and standard deviation (dim=1) of outputs. + torch.Tensor: A tensor containing the mean (dim=0) and + standard deviation (dim=1) of outputs. """ results = [] for _ in range(n_iter): diff --git a/eXNN/topology/homologies.py b/eXNN/topology/homologies.py index e5a0196..862984e 100644 --- a/eXNN/topology/homologies.py +++ b/eXNN/topology/homologies.py @@ -44,7 +44,8 @@ def _diagram_to_barcode(plot) -> Dict[str, np.ndarray]: plot: The plot object containing persistence diagram data. Returns: - Dict[str, np.ndarray]: A dictionary where keys are homology types, and values are arrays of intervals. + Dict[str, np.ndarray]: A dictionary where keys are homology types, and + values are arrays of intervals. """ data = plot["data"] homologies = {} @@ -88,7 +89,8 @@ def plot_barcode(barcode: Dict[str, np.ndarray]) -> plt.Figure: return fig -def compute_data_barcode(data: torch.Tensor, hom_type: str, coefficient_type: str) -> Dict[str, np.ndarray]: +def compute_data_barcode(data: torch.Tensor, hom_type: str, coefficient_type: str) -> Dict[ + str, np.ndarray]: """ Computes a barcode for the given data using persistent homology. diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py index 85db57b..6013812 100644 --- a/eXNN/topology/metrics.py +++ b/eXNN/topology/metrics.py @@ -40,7 +40,8 @@ def compute_metric(barcode: Dict[str, np.ndarray], metric_name: str = None): Args: barcode (Dict[str, np.ndarray]): The barcode to compute metrics for. - metric_name (str, optional): The specific metric name to compute. If None, all metrics are computed. + metric_name (str, optional): The specific metric name to compute. + If None, all metrics are computed. Returns: float or Dict[str, float]: The computed metric(s). diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py index 290fb7c..7861998 100644 --- a/tests/test_bayesian/test_bayes.py +++ b/tests/test_bayesian/test_bayes.py @@ -29,13 +29,16 @@ def _test_bayes_prediction(mode: str): if mode != 'gauss': res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) else: - res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, + n_iter=n_iter) utils.compare_values(dict, type(res), "Wrong result type") utils.compare_values(2, len(res), "Wrong dictionary length") utils.compare_values({"mean", "std"}, set(res.keys()), "Wrong dictionary keys") - utils.compare_values(torch.Size([n, num_classes]), res["mean"].shape, "Wrong mean shape") - utils.compare_values(torch.Size([n, num_classes]), res["std"].shape, "Wrong mean std") + utils.compare_values(torch.Size([n, num_classes]), res["mean"].shape, + "Wrong mean shape") + utils.compare_values(torch.Size([n, num_classes]), res["std"].shape, + "Wrong mean std") def test_basic_bayes_wrapper(): diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py index 31a6206..859f0b3 100644 --- a/tests/test_visualization/test_visualization.py +++ b/tests/test_visualization/test_visualization.py @@ -51,7 +51,8 @@ def test_embed_visualization(): model = utils.create_testing_model_lstm() layers = ["second_layer", "third_layer"] - res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels) + res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, + labels=labels) utils.compare_values(dict, type(res), "Wrong result type") utils.compare_values(2, len(res), "Wrong dictionary length") diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 47c85f7..ef0f82c 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -29,9 +29,8 @@ def compare_values(expected, got, message_header=None): Raises: AssertionError: If the values do not match, an error with the message header is raised. """ - assert ( - expected == got - ), f"{_form_message_header(message_header)}: expected {expected}, got {got}" + assert (expected == got), \ + f"{_form_message_header(message_header)}: expected {expected}, got {got}" def create_testing_data(): From c57bd3b86d8a3d51391cb1a29cc5080bcc4a59f3 Mon Sep 17 00:00:00 2001 From: Asmoorr <90kidex90@gmail.com> Date: Tue, 14 Jan 2025 13:12:08 +0300 Subject: [PATCH 16/16] fix: improve PEP8 compliance with additional code style adjustments --- eXNN/topology/homologies.py | 7 +++-- tests/test_topology/test_barcode_general.py | 35 ++++++++++++++++++--- tests/test_topology/test_barcode_metrics.py | 17 ++++++++-- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/eXNN/topology/homologies.py b/eXNN/topology/homologies.py index 862984e..19f470d 100644 --- a/eXNN/topology/homologies.py +++ b/eXNN/topology/homologies.py @@ -89,8 +89,11 @@ def plot_barcode(barcode: Dict[str, np.ndarray]) -> plt.Figure: return fig -def compute_data_barcode(data: torch.Tensor, hom_type: str, coefficient_type: str) -> Dict[ - str, np.ndarray]: +def compute_data_barcode( + data: torch.Tensor, + hom_type: str, + coefficient_type: str +) -> Dict[str, np.ndarray]: """ Computes a barcode for the given data using persistent homology. diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py index 174ce28..22f5b57 100644 --- a/tests/test_topology/test_barcode_general.py +++ b/tests/test_topology/test_barcode_general.py @@ -73,11 +73,36 @@ def test_all_barcodes(): # Test NN barcodes model = utils.create_testing_model() layers = ["second_layer", "third_layer"] - nn_barcodes = topology_api.get_nn_barcodes(model, data, layers, "standard", "3") - utils.compare_values(dict, type(nn_barcodes), "Wrong result type for NN barcodes") - utils.compare_values(2, len(nn_barcodes), "Wrong dictionary length for NN barcodes") - utils.compare_values(set(layers), set(nn_barcodes.keys()), "Wrong dictionary keys for NN barcodes") + nn_barcodes = topology_api.get_nn_barcodes( + model, + data, + layers, + "standard", + "3" + ) + + utils.compare_values( + dict, + type(nn_barcodes), + "Wrong result type for NN barcodes" + ) + + utils.compare_values( + 2, + len(nn_barcodes), + "Wrong dictionary length for NN barcodes" + ) + + utils.compare_values( + set(layers), + set(nn_barcodes.keys()), + "Wrong dictionary keys for NN barcodes" + ) # Test barcode plot barcode_plot = topology_api.plot_barcode(res) - utils.compare_values(matplotlib.figure.Figure, type(barcode_plot), "Wrong result type for barcode plot") + utils.compare_values( + matplotlib.figure.Figure, + type(barcode_plot), + "Wrong result type for barcode plot" + ) diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py index 53a3a29..0a77073 100644 --- a/tests/test_topology/test_barcode_metrics.py +++ b/tests/test_topology/test_barcode_metrics.py @@ -75,13 +75,24 @@ def test_barcode_evaluate_all_metrics_and_individual(): ] # Check that the dictionary keys match the expected metric names - utils.compare_values(all_metric_names, sorted(result.keys()), "Wrong dictionary keys") + utils.compare_values( + all_metric_names, + sorted(result.keys()), + "Wrong dictionary keys" + ) # Ensure all metric values are floats for name, value in result.items(): - utils.compare_values(float, type(value), f"Wrong result type for metric {name}") + utils.compare_values( + float, type(value), + f"Wrong result type for metric {name}" + ) # Test for evaluating individual metrics for metric_name in all_metric_names: individual_result = topology_api.evaluate_barcode(barcode, metric_name=metric_name) - utils.compare_values(float, type(individual_result), f"Wrong result type for individual metric {metric_name}") + utils.compare_values( + float, + type(individual_result), + f"Wrong result type for individual metric {metric_name}" + )