From e4f76a331bea5c6e1174bcca1c762d7e8a9ac0fc Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 02:18:03 +0300
Subject: [PATCH 01/16] feat: add docstrings for functions
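
A usage sketch of the documented API (for illustration only; assumes a
barcode dictionary with an "H0" key holding (birth, death) pairs, the
format consumed by _get_lengths):

    from eXNN.topology.metrics import compute_metric

    barcode = {"H0": [(0.0, 1.0), (0.0, 0.5), (0.1, 0.3)]}
    all_metrics = compute_metric(barcode)              # dict of every metric
    mean_len = compute_metric(barcode, "mean_length")  # a single value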
---
eXNN/topology/metrics.py | 224 +++++++++++++++++++++++++++++++++------
1 file changed, 191 insertions(+), 33 deletions(-)
diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py
index 7b5f71a..d5b5111 100644
--- a/eXNN/topology/metrics.py
+++ b/eXNN/topology/metrics.py
@@ -1,25 +1,30 @@
import heapq
-
import numpy as np
def _get_available_metrics():
+ """
+ Returns a dictionary of available metric names and their corresponding functions.
+
+ Returns:
+ dict: Mapping of metric names to functions.
+ """
return {
- # absolute length based metrics
+ # Absolute length-based metrics
"max_length": _compute_longest_interval_metric,
"mean_length": _compute_length_mean_metric,
"median_length": _compute_length_median_metric,
"stdev_length": _compute_length_stdev_metric,
"sum_length": _compute_length_sum_metric,
- # relative length based metrics
+ # Relative length-based metrics
"ratio_2_1": _compute_two_to_one_ratio_metric,
"ratio_3_1": _compute_three_to_one_ratio_metric,
- # entopy based metrics
+ # Entropy-based metrics
"h": _compute_entropy_metric,
"normh": _compute_normed_entropy_metric,
- # signal to noise ration
+ # Signal-to-noise ratio
"snr": _compute_snr_metric,
- # birth-death based metrics
+ # Birth-death based metrics
"mean_birth": _compute_births_mean_metric,
"stdev_birth": _compute_births_stdev_metric,
"mean_death": _compute_deaths_mean_metric,
@@ -28,119 +33,272 @@ def _get_available_metrics():
def compute_metric(barcode, metric_name=None):
+ """
+ Compute the specified metric or all available metrics for a given barcode.
+
+ Args:
+ barcode (dict): The barcode data containing persistent homology intervals.
+ metric_name (str, optional): The name of the metric to compute. Defaults to None.
+
+ Returns:
+ dict or float: A dictionary of all metrics if metric_name is None, otherwise the value of the specified metric.
+ """
metrics = _get_available_metrics()
if metric_name is None:
- return {name: fn(barcode) for (name, fn) in metrics.items()}
- else:
- return metrics[metric_name](barcode)
+ return {name: fn(barcode) for name, fn in metrics.items()}
+ return metrics[metric_name](barcode)
def _get_lengths(barcode):
+ """
+ Compute lengths of intervals in the H0 component of the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ list: A list of interval lengths.
+ """
diag = barcode["H0"]
return [d[1] - d[0] for d in diag]
def _compute_longest_interval_metric(barcode):
+ """
+ Compute the length of the longest interval in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The length of the longest interval.
+ """
lengths = _get_lengths(barcode)
return np.max(lengths).item()
def _compute_length_mean_metric(barcode):
+ """
+ Compute the mean length of intervals in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The mean length of intervals.
+ """
lengths = _get_lengths(barcode)
return np.mean(lengths).item()
def _compute_length_median_metric(barcode):
+ """
+ Compute the median length of intervals in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The median length of intervals.
+ """
lengths = _get_lengths(barcode)
return np.median(lengths).item()
def _compute_length_stdev_metric(barcode):
+ """
+ Compute the standard deviation of interval lengths in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The standard deviation of interval lengths.
+ """
lengths = _get_lengths(barcode)
return np.std(lengths).item()
def _compute_length_sum_metric(barcode):
+ """
+ Compute the sum of interval lengths in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The sum of interval lengths.
+ """
lengths = _get_lengths(barcode)
return np.sum(lengths).item()
-# Proportion between the longest intervals: 2/1 ratio, 3/1 ratio
def _compute_two_to_one_ratio_metric(barcode):
+ """
+ Compute the ratio of the second longest to the longest interval in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The 2-to-1 ratio of interval lengths.
+ """
lengths = _get_lengths(barcode)
value = heapq.nlargest(2, lengths)[1] / lengths[0]
- return value.item()
+ return value
def _compute_three_to_one_ratio_metric(barcode):
+ """
+ Compute the ratio of the third longest to the longest interval in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The 3-to-1 ratio of interval lengths.
+ """
lengths = _get_lengths(barcode)
value = heapq.nlargest(3, lengths)[2] / lengths[0]
- return value.item()
+ return value
-# Compute the persistent entropy and normed persistent entropy
-def _get_entropy(values, normalize: bool):
+
+def _get_entropy(values, normalize):
+ """
+ Compute the entropy of a set of values.
+
+ Args:
+ values (list): The values for which to compute entropy.
+ normalize (bool): Whether to normalize the entropy.
+
+ Returns:
+ float: The computed entropy.
+ """
values_sum = np.sum(values)
- entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum)))
+ entropy = -np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum)))
if normalize:
- entropy = entropy / np.log(values_sum)
+ entropy /= np.log(values_sum)
return entropy
def _compute_entropy_metric(barcode):
+ """
+ Compute the persistent entropy of the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The persistent entropy.
+ """
return _get_entropy(_get_lengths(barcode), normalize=False).item()
def _compute_normed_entropy_metric(barcode):
+ """
+ Compute the normalized persistent entropy of the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The normalized persistent entropy.
+ """
return _get_entropy(_get_lengths(barcode), normalize=True).item()
-# Compute births
def _get_births(barcode):
+ """
+ Extract the birth times from the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ np.ndarray: The birth times.
+ """
diag = barcode["H0"]
return np.array([x[0] for x in diag])
-# Comput deaths
def _get_deaths(barcode):
+ """
+ Extract the death times from the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ np.ndarray: The death times.
+ """
diag = barcode["H0"]
return np.array([x[1] for x in diag])
-# def _get_birth(barcode, dim):
-# diag = barcode['H0']
-# temp = np.array([x[0] for x in diag if x[2] == dim])
-# return temp[0]
-# def _get_death(barcode, dim):
-# diag = barcode['H0']
-# temp = np.array([x[1] for x in diag if x[2] == dim])
-# return temp[-1]
-# Compute SNR
def _compute_snr_metric(barcode):
+ """
+ Compute the signal-to-noise ratio (SNR) of the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The SNR value.
+ """
births = _get_births(barcode)
deaths = _get_deaths(barcode)
signal = np.mean(deaths - births)
noise = np.std(births)
- snr = signal / noise
- return snr.item()
+ return (signal / noise).item()
-# Compute the birth-death pair indices: Birth mean, birth stdev, death mean, death stdev
def _compute_births_mean_metric(barcode):
+ """
+ Compute the mean of birth times in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The mean birth time.
+ """
return np.mean(_get_births(barcode)).item()
def _compute_births_stdev_metric(barcode):
+ """
+ Compute the standard deviation of birth times in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The standard deviation of birth times.
+ """
return np.std(_get_births(barcode)).item()
def _compute_deaths_mean_metric(barcode):
+ """
+ Compute the mean of death times in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The mean death time.
+ """
return np.mean(_get_deaths(barcode)).item()
def _compute_deaths_stdev_metric(barcode):
+ """
+ Compute the standard deviation of death times in the barcode.
+
+ Args:
+ barcode (dict): The barcode data.
+
+ Returns:
+ float: The standard deviation of death times.
+ """
return np.std(_get_deaths(barcode)).item()
From 54192d1079c33af1676485a5af59f8754e0dd729 Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 02:26:41 +0300
Subject: [PATCH 02/16] feat: add docstrings and line breaks; rename some variables
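
For reference, a minimal standalone sketch of the forward-hook pattern
that _get_activation uses (hypothetical names; assumes `model` is an
nn.Module with a submodule called "second_layer" and `x` is a matching
input tensor):

    activation = {}

    def store_output(name):
        def hook(_, __, output):
            activation[name] = output.detach()
        return hook

    handle = model.second_layer.register_forward_hook(store_output("second_layer"))
    model(x)
    handle.remove()  # detach the hook so it does not fire on later passes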
---
eXNN/topology/homologies.py | 170 ++++++++++++++++++++----------------
1 file changed, 96 insertions(+), 74 deletions(-)
diff --git a/eXNN/topology/homologies.py b/eXNN/topology/homologies.py
index 6137ed7..e5a0196 100644
--- a/eXNN/topology/homologies.py
+++ b/eXNN/topology/homologies.py
@@ -10,121 +10,143 @@
)
-def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str):
+def _get_activation(model: torch.nn.Module, x: torch.Tensor, layer: str) -> torch.Tensor:
+ """
+ Extracts the activation output of a specified layer from a given model.
+
+ Args:
+ model (torch.nn.Module): The neural network model.
+ x (torch.Tensor): The input tensor to the model.
+ layer (str): The name of the layer to extract activation from.
+
+ Returns:
+ torch.Tensor: The activation output of the specified layer.
+ """
activation = {}
def store_output(name):
- def hook(model, input, output):
+ def hook(_, __, output):
activation[name] = output.detach()
return hook
- h1 = getattr(model, layer).register_forward_hook(store_output(layer))
+ hook_handle = getattr(model, layer).register_forward_hook(store_output(layer))
model.forward(x)
- h1.remove()
+ hook_handle.remove()
return activation[layer]
-def _diagram_to_barcode(plot):
+def _diagram_to_barcode(plot) -> Dict[str, np.ndarray]:
+ """
+ Converts a persistence diagram plot into a barcode representation.
+
+ Args:
+ plot: The plot object containing persistence diagram data.
+
+ Returns:
+ Dict[str, np.ndarray]: A dictionary where keys are homology types, and values are arrays of intervals.
+ """
data = plot["data"]
homologies = {}
- for h in data:
- if h["name"] is None:
+ for homology in data:
+ if homology["name"] is None:
continue
- homologies[h["name"]] = list(zip(h["x"], h["y"]))
+ homologies[homology["name"]] = list(zip(homology["x"], homology["y"]))
- for h in homologies.keys():
- homologies[h] = sorted(homologies[h], key=lambda x: x[0])
+ for key in homologies.keys():
+ homologies[key] = sorted(homologies[key], key=lambda x: x[0])
return homologies
-def plot_barcode(barcode: Dict[str, np.ndarray]):
+def plot_barcode(barcode: Dict[str, np.ndarray]) -> plt.Figure:
+ """
+ Plots a barcode diagram from given intervals.
+
+ Args:
+ barcode (Dict[str, np.ndarray]): A dictionary containing homology types and their intervals.
+
+ Returns:
+ plt.Figure: The matplotlib figure containing the barcode diagram.
+ """
homologies = list(barcode.keys())
- nplots = len(homologies)
- fig, ax = plt.subplots(nplots, figsize=(15, min(10, 5 * nplots)))
+ n_plots = len(homologies)
+ fig, axes = plt.subplots(n_plots, figsize=(15, min(10, 5 * n_plots)))
- if nplots == 1:
- ax = [ax]
+ if n_plots == 1:
+ axes = [axes]
- for i in range(nplots):
- name = homologies[i]
- ax[i].set_title(name)
- ax[i].set_ylim([-0.05, 1.05])
+ for i, name in enumerate(homologies):
+ axes[i].set_title(name)
+ axes[i].set_ylim([-0.05, 1.05])
bars = barcode[name]
n = len(bars)
- for j in range(n):
- bar = bars[j]
- ax[i].plot([bar[0], bar[1]], [j / n, j / n], color="black")
- labels = ["" for _ in range(len(ax[i].get_yticklabels()))]
- ax[i].set_yticks(ax[i].get_yticks())
- ax[i].set_yticklabels(labels)
-
- if nplots == 1:
- ax = ax[0]
+ for j, bar in enumerate(bars):
+ axes[i].plot([bar[0], bar[1]], [j / n, j / n], color="black")
+ axes[i].set_yticks([])
+
plt.close(fig)
return fig
-def compute_data_barcode(data: torch.Tensor, hom_type: str, coefs_type: str):
+def compute_data_barcode(data: torch.Tensor, hom_type: str, coefficient_type: str) -> Dict[str, np.ndarray]:
+ """
+ Computes a barcode for the given data using persistent homology.
+
+ Args:
+ data (torch.Tensor): The input data for barcode computation.
+ hom_type (str): The type of homology to use ("standard", "sparse", "weak").
+ coefficient_type (str): The coefficient field for homology computation.
+
+ Returns:
+ Dict[str, np.ndarray]: The computed barcode as a dictionary.
+
+ Raises:
+ ValueError: If an invalid hom_type is provided.
+ """
if hom_type == "standard":
- VR = VietorisRipsPersistence(
+ vr = VietorisRipsPersistence(
homology_dimensions=[0],
collapse_edges=True,
- coeff=int(coefs_type),
+ coeff=int(coefficient_type),
)
elif hom_type == "sparse":
- VR = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefs_type))
+ vr = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefficient_type))
elif hom_type == "weak":
- VR = WeakAlphaPersistence(
+ vr = WeakAlphaPersistence(
homology_dimensions=[0],
collapse_edges=True,
- coeff=int(coefs_type),
+ coeff=int(coefficient_type),
)
else:
- raise Exception('hom_type must be one of: "standard", "sparse", "weak"!')
+ raise ValueError('hom_type must be one of: "standard", "sparse", "weak"!')
- if len(data.shape) > 2:
+ if data.ndim > 2:
data = torch.nn.Flatten()(data)
data = data.reshape(1, *data.shape)
- diagrams = VR.fit_transform(data)
- plot = VR.plot(diagrams)
+ diagrams = vr.fit_transform(data)
+ plot = vr.plot(diagrams)
return _diagram_to_barcode(plot)
def compute_nn_barcode(
- model: torch.nn.Module,
- x: torch.Tensor,
- layer: str,
- hom_type: str,
- coefs_type: str,
-):
- act = _get_activation(model, x, layer)
- return compute_data_barcode(act, hom_type, coefs_type)
-
-
-# def get_homologies_experimental(
-# model: torch.nn.Module,
-# x: torch.Tensor,
-# layer: str,
-# dimensions: Optional[List[int]] = None,
-# make_barplot: bool = True,
-# rm_empty: bool = True,
-# ):
-# act = _get_activation(model, x, layer)
-# act = act.reshape(1, *act.shape)
-# # Dimensions must not be outside layer dimensionality
-# N = act.shape[-1]
-# dimensions = dimensions if dimensions is not None else []
-# dimensions = [i if i >= 0 else N + i for i in dimensions]
-# dimensions = [i for i in dimensions if ((i >= 0) and (i < N))]
-# dimensions = list(set(dimensions))
-# VR = VietorisRipsPersistence(homology_dimensions=dimensions, collapse_edges=True)
-# diagrams = VR.fit_transform(act)
-# plot = VR.plot(diagrams)
-# if make_barplot:
-# barcode = _diagram_to_barcode(plot)
-# if rm_empty:
-# barcode = {key: val for key, val in barcode.items() if len(val) > 0}
-# return _plot_barcode(barcode)
-# else:
-# return plot
+ model: torch.nn.Module,
+ x: torch.Tensor,
+ layer: str,
+ hom_type: str,
+ coefficient_type: str,
+) -> Dict[str, np.ndarray]:
+ """
+ Computes a barcode for a specified layer in a neural network model.
+
+ Args:
+ model (torch.nn.Module): The neural network model.
+ x (torch.Tensor): The input data for the model.
+ layer (str): The layer to extract activations from.
+ hom_type (str): The type of homology to use ("standard", "sparse", "weak").
+ coefficient_type (str): The coefficient field for homology computation.
+
+ Returns:
+ Dict[str, np.ndarray]: The computed barcode as a dictionary.
+ """
+ activation = _get_activation(model, x, layer)
+ return compute_data_barcode(activation, hom_type, coefficient_type)
From d8c8fe4a6cf8669d08502617f9510ccd838da1f2 Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 02:32:52 +0300
Subject: [PATCH 03/16] fix: change return type to float
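
Rationale, as a small sketch: numpy reductions return numpy scalars, so
strict type checks against the built-in float fail until the value is
converted explicitly:

    import numpy as np

    x = np.mean([1.0, 2.0])
    type(x) is float          # False: np.mean returns np.float64
    type(float(x)) is float   # True after explicit conversion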
---
eXNN/topology/metrics.py | 183 ++++++++++++++++++++-------------------
1 file changed, 93 insertions(+), 90 deletions(-)
diff --git a/eXNN/topology/metrics.py b/eXNN/topology/metrics.py
index d5b5111..85db57b 100644
--- a/eXNN/topology/metrics.py
+++ b/eXNN/topology/metrics.py
@@ -1,13 +1,15 @@
import heapq
+from typing import Dict, List, Optional, Union
+
import numpy as np
def _get_available_metrics():
"""
- Returns a dictionary of available metric names and their corresponding functions.
+ Returns a dictionary mapping metric names to their respective computation functions.
Returns:
- dict: Mapping of metric names to functions.
+ Dict[str, callable]: A dictionary of metric computation functions.
"""
return {
# Absolute length-based metrics
@@ -32,16 +34,16 @@ def _get_available_metrics():
}
-def compute_metric(barcode, metric_name=None):
+def compute_metric(
+ barcode: Dict[str, np.ndarray],
+ metric_name: Optional[str] = None,
+) -> Union[float, Dict[str, float]]:
"""
- Compute the specified metric or all available metrics for a given barcode.
+ Computes specified or all metrics for a given barcode.
Args:
- barcode (dict): The barcode data containing persistent homology intervals.
- metric_name (str, optional): The name of the metric to compute. Defaults to None.
+ barcode (Dict[str, np.ndarray]): The barcode to compute metrics for.
+ metric_name (str, optional): The specific metric name to compute. If None, all metrics are computed.
Returns:
- dict or float: A dictionary of all metrics if metric_name is None, otherwise the value of the specified metric.
+ float or Dict[str, float]: The computed metric(s).
"""
metrics = _get_available_metrics()
if metric_name is None:
@@ -49,256 +51,257 @@ def compute_metric(barcode, metric_name=None):
return metrics[metric_name](barcode)
-def _get_lengths(barcode):
+def _get_lengths(barcode: Dict[str, np.ndarray]) -> List[float]:
"""
- Compute lengths of intervals in the H0 component of the barcode.
+ Extracts lengths of intervals from a barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- list: A list of interval lengths.
+ List[float]: A list of interval lengths.
"""
diag = barcode["H0"]
return [d[1] - d[0] for d in diag]
-def _compute_longest_interval_metric(barcode):
+def _compute_longest_interval_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the length of the longest interval in the barcode.
+ Computes the maximum interval length in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The length of the longest interval.
+ float: The maximum interval length.
"""
lengths = _get_lengths(barcode)
- return np.max(lengths).item()
+ return float(np.max(lengths))
-def _compute_length_mean_metric(barcode):
+def _compute_length_mean_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the mean length of intervals in the barcode.
+ Computes the mean interval length in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The mean length of intervals.
+ float: The mean interval length.
"""
lengths = _get_lengths(barcode)
- return np.mean(lengths).item()
+ return float(np.mean(lengths))
-def _compute_length_median_metric(barcode):
+def _compute_length_median_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the median length of intervals in the barcode.
+ Computes the median interval length in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The median length of intervals.
+ float: The median interval length.
"""
lengths = _get_lengths(barcode)
- return np.median(lengths).item()
+ return float(np.median(lengths))
-def _compute_length_stdev_metric(barcode):
+def _compute_length_stdev_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the standard deviation of interval lengths in the barcode.
+ Computes the standard deviation of interval lengths in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
float: The standard deviation of interval lengths.
"""
lengths = _get_lengths(barcode)
- return np.std(lengths).item()
+ return float(np.std(lengths))
-def _compute_length_sum_metric(barcode):
+def _compute_length_sum_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the sum of interval lengths in the barcode.
+ Computes the sum of all interval lengths in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The sum of interval lengths.
+ float: The sum of all interval lengths.
"""
lengths = _get_lengths(barcode)
- return np.sum(lengths).item()
+ return float(np.sum(lengths))
-def _compute_two_to_one_ratio_metric(barcode):
+def _compute_two_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the ratio of the second longest to the longest interval in the barcode.
+ Computes the ratio of the second largest to the largest interval length.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The 2-to-1 ratio of interval lengths.
+ float: The ratio of the second largest to the largest interval length.
"""
lengths = _get_lengths(barcode)
value = heapq.nlargest(2, lengths)[1] / lengths[0]
- return value
+ return float(value)
-def _compute_three_to_one_ratio_metric(barcode):
+def _compute_three_to_one_ratio_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the ratio of the third longest to the longest interval in the barcode.
+ Computes the ratio of the third largest to the largest interval length.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The 3-to-1 ratio of interval lengths.
+ float: The ratio of the third largest to the largest interval length.
"""
lengths = _get_lengths(barcode)
value = heapq.nlargest(3, lengths)[2] / lengths[0]
- return value
+ return float(value)
-def _get_entropy(values, normalize):
+def _get_entropy(values: np.ndarray, normalize: bool) -> float:
"""
- Compute the entropy of a set of values.
+ Computes the entropy of a given distribution.
Args:
- values (list): The values for which to compute entropy.
+ values (np.ndarray): The values to compute entropy for.
normalize (bool): Whether to normalize the entropy.
Returns:
float: The computed entropy.
"""
values_sum = np.sum(values)
- entropy = -np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum)))
+ entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum)))
if normalize:
- entropy /= np.log(values_sum)
- return entropy
+ entropy = entropy / np.log(values_sum)
+ return float(entropy)
-def _compute_entropy_metric(barcode):
+def _compute_entropy_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the persistent entropy of the barcode.
+ Computes the persistent entropy of intervals in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
float: The persistent entropy.
"""
- return _get_entropy(_get_lengths(barcode), normalize=False).item()
+ return _get_entropy(_get_lengths(barcode), normalize=False)
-def _compute_normed_entropy_metric(barcode):
+def _compute_normed_entropy_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the normalized persistent entropy of the barcode.
+ Computes the normalized persistent entropy of intervals in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
float: The normalized persistent entropy.
"""
- return _get_entropy(_get_lengths(barcode), normalize=True).item()
+ return _get_entropy(_get_lengths(barcode), normalize=True)
-def _get_births(barcode):
+def _get_births(barcode: Dict[str, np.ndarray]) -> np.ndarray:
"""
- Extract the birth times from the barcode.
+ Extracts the birth times from the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- np.ndarray: The birth times.
+ np.ndarray: An array of birth times.
"""
diag = barcode["H0"]
return np.array([x[0] for x in diag])
-def _get_deaths(barcode):
+def _get_deaths(barcode: Dict[str, np.ndarray]) -> np.ndarray:
"""
- Extract the death times from the barcode.
+ Extracts the death times from the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- np.ndarray: The death times.
+ np.ndarray: An array of death times.
"""
diag = barcode["H0"]
return np.array([x[1] for x in diag])
-def _compute_snr_metric(barcode):
+def _compute_snr_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the signal-to-noise ratio (SNR) of the barcode.
+ Computes the signal-to-noise ratio (SNR) for the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The SNR value.
+ float: The computed SNR.
"""
births = _get_births(barcode)
deaths = _get_deaths(barcode)
signal = np.mean(deaths - births)
noise = np.std(births)
- return (signal / noise).item()
+ snr = signal / noise
+ return float(snr)
-def _compute_births_mean_metric(barcode):
+def _compute_births_mean_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the mean of birth times in the barcode.
+ Computes the mean of birth times in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The mean birth time.
+ float: The mean of birth times.
"""
- return np.mean(_get_births(barcode)).item()
+ return float(np.mean(_get_births(barcode)))
-def _compute_births_stdev_metric(barcode):
+def _compute_births_stdev_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the standard deviation of birth times in the barcode.
+ Computes the standard deviation of birth times in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
float: The standard deviation of birth times.
"""
- return np.std(_get_births(barcode)).item()
+ return float(np.std(_get_births(barcode)))
-def _compute_deaths_mean_metric(barcode):
+def _compute_deaths_mean_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the mean of death times in the barcode.
+ Computes the mean of death times in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
- float: The mean death time.
+ float: The mean of death times.
"""
- return np.mean(_get_deaths(barcode)).item()
+ return float(np.mean(_get_deaths(barcode)))
-def _compute_deaths_stdev_metric(barcode):
+def _compute_deaths_stdev_metric(barcode: Dict[str, np.ndarray]) -> float:
"""
- Compute the standard deviation of death times in the barcode.
+ Computes the standard deviation of death times in the barcode.
Args:
- barcode (dict): The barcode data.
+ barcode (Dict[str, np.ndarray]): The barcode to process.
Returns:
float: The standard deviation of death times.
"""
- return np.std(_get_deaths(barcode)).item()
+ return float(np.std(_get_deaths(barcode)))
From 4ac776c7601c3ba713a761d70d4ae985ce7f9e49 Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 02:37:34 +0300
Subject: [PATCH 04/16] feat: update docstrings
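
Usage sketch for the documented API (mirrors the calls exercised in the
test suite; "standard" homology over the coefficient field "3"):

    import torch
    import eXNN.topology as topology_api

    data = torch.randn((20, 256))
    barcode = topology_api.get_data_barcode(data, "standard", "3")
    all_metrics = topology_api.evaluate_barcode(barcode)             # Dict[str, float]
    snr = topology_api.evaluate_barcode(barcode, metric_name="snr")  # float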
---
eXNN/topology/api.py | 58 ++++++++++++++++++++++----------------------
1 file changed, 29 insertions(+), 29 deletions(-)
diff --git a/eXNN/topology/api.py b/eXNN/topology/api.py
index 70e70cb..36d4433 100644
--- a/eXNN/topology/api.py
+++ b/eXNN/topology/api.py
@@ -10,21 +10,22 @@
def get_data_barcode(
data: torch.Tensor,
hom_type: str,
- coefs_type: str,
+ coefficient_type: str,
) -> Dict[str, np.ndarray]:
- """This function computes persistent homologies for a cloud of points as barcodes.
+ """
+ Computes persistent homologies for a cloud of points as barcodes.
Args:
- data (torch.Tensor): input data of shape NxC1x...xCk,
+ data (torch.Tensor): Input data of shape NxC1x...xCk,
where N is the number of data points,
- C1,...,Ck are dimensions of each data point
- hom_type (str): homotopy type
- coefs_type (str): coefficients type
+ C1,...,Ck are dimensions of each data point.
+ hom_type (str): Homotopy type.
+ coefficient_type (str): Coefficients type.
Returns:
- Dict[str, np.ndarray]: barcode
+ Dict[str, np.ndarray]: Barcode.
"""
- return homologies.compute_data_barcode(data, hom_type, coefs_type)
+ return homologies.compute_data_barcode(data, hom_type, coefficient_type)
def get_nn_barcodes(
@@ -32,40 +33,39 @@ def get_nn_barcodes(
data: torch.Tensor,
layers: List[str],
hom_type: str,
- coefs_type: str,
+ coefficient_type: str,
) -> Dict[str, Dict[str, np.ndarray]]:
"""
- The function plots persistent homologies for latent representations
- on different levels of the neural network as barcodes.
+ Computes persistent homologies for latent representations on different
+ levels of the neural network as barcodes.
Args:
- model (torch.nn.Module): neural network
- data (torch.Tensor): input data of shape NxC1x...xCk,
+ model (torch.nn.Module): Neural network.
+ data (torch.Tensor): Input data of shape NxC1x...xCk,
where N is the number of data points,
- C1,...,Ck are dimensions of each data point
- layers (List[str]): list of layers for visualization. Defaults to None.
- If None, visualization for all layers is performed
- hom_type (str): homotopy type
- coefs_type (str): coefficients type
+ C1,...,Ck are dimensions of each data point.
+ layers (List[str]): List of layers for visualization.
+ hom_type (str): Homotopy type.
+ coefficient_type (str): Coefficients type.
Returns:
- Dict[str, Dict[str, np.ndarray]]: dictionary with a barcode for each layer
+ Dict[str, Dict[str, np.ndarray]]: Dictionary with a barcode for each layer.
"""
res = {}
for layer in layers:
- res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefs_type)
+ res[layer] = homologies.compute_nn_barcode(model, data, layer, hom_type, coefficient_type)
return res
def plot_barcode(barcode: Dict[str, np.ndarray]) -> matplotlib.figure.Figure:
"""
- The function creates a plot of a persistent homologies barcode.
+ Creates a plot of a persistent homologies barcode.
Args:
- barcode (Dict[str, np.ndarray]): barcode
+ barcode (Dict[str, np.ndarray]): Barcode.
Returns:
- matplotlib.figure.Figure: a plot of the barcode
+ matplotlib.figure.Figure: Plot of the barcode.
"""
return homologies.plot_barcode(barcode)
@@ -74,15 +74,15 @@ def evaluate_barcode(
barcode: Dict[str, np.ndarray], metric_name: Optional[str] = None
) -> Union[float, Dict[str, float]]:
"""
- The function evaluates a persistent homologies barcode with a metric.
+ Evaluates a persistent homologies barcode with a metric.
Args:
- barcode (Dict[str, np.ndarray]): barcode
- metric_name (Optional[str]): metric name
- (if `None` all available metrics values are computed)
+ barcode (Dict[str, np.ndarray]): Barcode.
+ metric_name (Optional[str]): Metric name (if None, all available metrics
+ values are computed).
Returns:
- Union(float, Dict[str, float]): float if metric is specified
- or a dictionary with value of each available metric
+ Union[float, Dict[str, float]]: Float if metric is specified, or a
+ dictionary with values of each available metric.
"""
return metrics.compute_metric(barcode, metric_name)
From e8c2ef8028f9d785e8778d953748e9c3edf5ac80 Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 02:51:34 +0300
Subject: [PATCH 05/16] feat: add docstrings
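
Usage sketch (a minimal illustration; assumes `model` is an nn.Module and
`data` a matching input batch, and that mean_forward stacks the two
statistics along dim 0 as its docstring describes):

    from eXNN.bayes.wrapper import create_dropout_bayesian_wrapper

    net = create_dropout_bayesian_wrapper(model, mode="basic", p=0.5)
    stats = net.mean_forward(data, n_iter=10)
    mean, std = stats[0], stats[1]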
---
eXNN/bayes/wrapper.py | 203 ++++++++++++++++++++++++++++++++++--------
1 file changed, 167 insertions(+), 36 deletions(-)
diff --git a/eXNN/bayes/wrapper.py b/eXNN/bayes/wrapper.py
index 8e1ed43..2681faf 100644
--- a/eXNN/bayes/wrapper.py
+++ b/eXNN/bayes/wrapper.py
@@ -3,18 +3,29 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as functional
from torch.distributions import Beta
class ModuleBayesianWrapper(nn.Module):
+ """
+ A wrapper for neural network layers to apply Bayesian-style dropout or noise during training.
+
+ Args:
+ layer (nn.Module): The layer to wrap (e.g., nn.Linear, nn.Conv2d).
+ p (Optional[float]): Dropout probability for simple dropout. Mutually exclusive with `a`, `b`, and `sigma`.
+ a (Optional[float]): Alpha parameter for Beta distribution dropout. Used with `b`.
+ b (Optional[float]): Beta parameter for Beta distribution dropout. Used with `a`.
+ sigma (Optional[float]): Standard deviation for Gaussian noise. Mutually exclusive with `p`, `a`, and `b`.
+ """
+
def __init__(
- self,
- layer: nn.Module,
- p: Optional[float] = None,
- a: Optional[float] = None,
- b: Optional[float] = None,
- sigma: Optional[float] = None,
+ self,
+ layer: nn.Module,
+ p: Optional[float] = None,
+ a: Optional[float] = None,
+ b: Optional[float] = None,
+ sigma: Optional[float] = None,
):
super(ModuleBayesianWrapper, self).__init__()
@@ -36,6 +47,16 @@ def __init__(
self.p, self.a, self.b, self.sigma = p, a, b, sigma
def augment_weights(self, weights, bias):
+ """
+ Apply the specified noise or dropout to the weights and bias.
+
+ Args:
+ weights (torch.Tensor): The weights of the layer.
+ bias (torch.Tensor): The bias of the layer (can be None).
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The augmented weights and bias.
+ """
# Check if dropout is chosen
if (self.p is not None) or (self.a is not None and self.b is not None):
@@ -45,10 +66,10 @@ def augment_weights(self, weights, bias):
else:
p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample()
- weights = F.dropout(weights, p, training=True)
+ weights = functional.dropout(weights, p, training=True)
if bias is not None:
# In layers we sometimes have the ability to set bias to None
- bias = F.dropout(bias, p, training=True)
+ bias = functional.dropout(bias, p, training=True)
else:
# If gauss is chosen, then apply it
@@ -60,11 +81,20 @@ def augment_weights(self, weights, bias):
return weights, bias
def forward(self, x):
+ """
+ Forward pass through the layer with augmented weights.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
weight, bias = self.augment_weights(self.layer.weight, self.layer.bias)
if isinstance(self.layer, nn.Linear):
- return F.linear(x, weight, bias)
+ return functional.linear(x, weight, bias)
elif type(self.layer) in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
return self.layer._conv_forward(x, weight, bias)
else:
@@ -72,6 +102,18 @@ def forward(self, x):
def replace_modules_with_wrapper(model, wrapper_module, params):
+ """
+ Recursively replaces layers in a model with a Bayesian wrapper.
+
+ Args:
+ model (nn.Module): The model containing layers to replace.
+ wrapper_module (type): The wrapper class.
+ params (dict): Parameters for the wrapper.
+
+ Returns:
+ nn.Module: The model with wrapped layers.
+ """
+
if type(model) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]:
return wrapper_module(model, **params)
@@ -88,12 +130,26 @@ def replace_modules_with_wrapper(model, wrapper_module, params):
class NetworkBayes(nn.Module):
+ """
+ Bayesian network with standard dropout.
+
+ Args:
+ model (nn.Module): The base model.
+ dropout_p (float): Dropout probability.
+ """
+
def __init__(
- self,
- model: nn.Module,
- dropout_p: float,
+ self,
+ model: nn.Module,
+ dropout_p: float,
):
+ """
+ Initialize the NetworkBayes with standard dropout.
+
+ Args:
+ model (nn.Module): The base model to wrap with Bayesian dropout.
+ dropout_p (float): Dropout probability for the Bayesian wrapper.
+ """
super(NetworkBayes, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(
@@ -103,11 +159,20 @@ def __init__(
)
def mean_forward(
- self,
- data: torch.Tensor,
- n_iter: int,
+ self,
+ data: torch.Tensor,
+ n_iter: int,
):
+ """
+ Perform forward passes to estimate the mean and standard deviation of outputs.
+
+ Args:
+ data (torch.Tensor): Input tensor.
+ n_iter (int): Number of stochastic forward passes.
+
+ Returns:
+ torch.Tensor: A tensor with the mean at index 0 and the standard deviation at index 1, stacked along dim 0.
+ """
results = []
for _ in range(n_iter):
results.append(self.model.forward(data))
@@ -125,13 +190,29 @@ def mean_forward(
# calculate mean and std after applying bayesian with beta distribution
class NetworkBayesBeta(nn.Module):
+ """
+ Bayesian network with Beta distribution dropout.
+
+ Args:
+ model (nn.Module): The base model.
+ alpha (float): Alpha parameter for the Beta distribution.
+ beta (float): Beta parameter for the Beta distribution.
+ """
+
def __init__(
- self,
- model: torch.nn.Module,
- alpha: float,
- beta: float,
+ self,
+ model: torch.nn.Module,
+ alpha: float,
+ beta: float,
):
-
+ """
+ Initialize the NetworkBayesBeta with Beta distribution dropout.
+
+ Args:
+ model (nn.Module): The base model to wrap with Bayesian Beta dropout.
+ alpha (float): Alpha parameter of the Beta distribution.
+ beta (float): Beta parameter of the Beta distribution.
+ """
super(NetworkBayesBeta, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(
@@ -141,11 +222,20 @@ def __init__(
)
def mean_forward(
- self,
- data: torch.Tensor,
- n_iter: int,
+ self,
+ data: torch.Tensor,
+ n_iter: int,
):
+ """
+ Perform forward passes to estimate the mean and standard deviation of outputs.
+
+ Args:
+ data (torch.Tensor): Input tensor.
+ n_iter (int): Number of stochastic forward passes.
+
+ Returns:
+ torch.Tensor: A tensor with the mean at index 0 and the standard deviation at index 1, stacked along dim 0.
+ """
results = []
for _ in range(n_iter):
results.append(self.model.forward(data))
@@ -163,12 +253,26 @@ def mean_forward(
class NetworkBayesGauss(nn.Module):
+ """
+ Bayesian network with Gaussian noise.
+
+ Args:
+ model (nn.Module): The base model.
+ sigma (float): Standard deviation of the Gaussian noise.
+ """
+
def __init__(
- self,
- model: torch.nn.Module,
- sigma: float,
+ self,
+ model: torch.nn.Module,
+ sigma: float,
):
+ """
+ Initialize the NetworkBayesGauss with Gaussian noise.
+
+ Args:
+ model (nn.Module): The base model to wrap with Bayesian Gaussian noise.
+ sigma (float): Standard deviation of the Gaussian noise to apply.
+ """
super(NetworkBayesGauss, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(
@@ -178,11 +282,20 @@ def __init__(
)
def mean_forward(
- self,
- data: torch.Tensor,
- n_iter: int,
+ self,
+ data: torch.Tensor,
+ n_iter: int,
):
+ """
+ Perform forward passes to estimate the mean and standard deviation of outputs.
+
+ Args:
+ data (torch.Tensor): Input tensor.
+ n_iter (int): Number of stochastic forward passes.
+
+ Returns:
+ torch.Tensor: A tensor with the mean at index 0 and the standard deviation at index 1, stacked along dim 0.
+ """
results = []
for _ in range(n_iter):
results.append(self.model.forward(data))
@@ -200,13 +313,28 @@ def mean_forward(
def create_dropout_bayesian_wrapper(
- model: torch.nn.Module,
- mode: Optional[str] = "basic",
- p: Optional[float] = None,
- a: Optional[float] = None,
- b: Optional[float] = None,
- sigma: Optional[float] = None,
+ model: torch.nn.Module,
+ mode: Optional[str] = "basic",
+ p: Optional[float] = None,
+ a: Optional[float] = None,
+ b: Optional[float] = None,
+ sigma: Optional[float] = None,
) -> torch.nn.Module:
+ """
+ Creates a Bayesian network with the specified dropout mode.
+
+ Args:
+ model (nn.Module): The base model.
+ mode (str): The dropout mode ("basic", "beta", "gauss").
+ p (Optional[float]): Dropout probability for "basic" mode.
+ a (Optional[float]): Alpha parameter for "beta" mode.
+ b (Optional[float]): Beta parameter for "beta" mode.
+ sigma (Optional[float]): Standard deviation for "gauss" mode.
+
+ Returns:
+ nn.Module: The Bayesian network.
+ """
+
if mode == "basic":
net = NetworkBayes(model, p)
@@ -216,4 +344,7 @@ def create_dropout_bayesian_wrapper(
elif mode == 'gauss':
net = NetworkBayesGauss(model, sigma)
+ else:
+ raise ValueError("Mode should be one of ('basic', 'beta', 'gauss').")
+
return net
From c0a88f42b4174003f1dc1a1e028fa32efd30da8d Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 03:31:35 +0300
Subject: [PATCH 06/16] feat: restructure tests
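
Shared helpers now live in tests/utils/test_utils.py, so imports change
from `tests.test_utils` to `tests.utils.test_utils`. A minimal sketch of
the helper usage common to the new test modules:

    import tests.utils.test_utils as utils

    n, dim, data = utils.create_testing_data()          # 20 x 256 random tensor
    model = utils.create_testing_model(num_classes=17)
    utils.compare_values(20, n, "Wrong sample count")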
---
tests/test_bayesian/__init__.py | 0
tests/test_bayesian/test_bayes.py | 60 ++++++
tests/test_topology/__init__.py | 0
tests/test_topology/test_barcode_general.py | 57 ++++++
tests/test_topology/test_barcode_metrics.py | 50 +++++
tests/test_utils.py | 53 ------
tests/test_visualization/__init__.py | 0
.../test_check_random_input.py | 18 ++
tests/test_visualization/test_reduce_dim.py | 42 +++++
.../test_visualization/test_visualization.py | 119 ++++++++++++
tests/tests.py | 172 ------------------
tests/utils/__init__.py | 0
tests/utils/test_utils.py | 124 +++++++++++++
13 files changed, 470 insertions(+), 225 deletions(-)
create mode 100644 tests/test_bayesian/__init__.py
create mode 100644 tests/test_bayesian/test_bayes.py
create mode 100644 tests/test_topology/__init__.py
create mode 100644 tests/test_topology/test_barcode_general.py
create mode 100644 tests/test_topology/test_barcode_metrics.py
delete mode 100644 tests/test_utils.py
create mode 100644 tests/test_visualization/__init__.py
create mode 100644 tests/test_visualization/test_check_random_input.py
create mode 100644 tests/test_visualization/test_reduce_dim.py
create mode 100644 tests/test_visualization/test_visualization.py
delete mode 100644 tests/tests.py
create mode 100644 tests/utils/__init__.py
create mode 100644 tests/utils/test_utils.py
diff --git a/tests/test_bayesian/__init__.py b/tests/test_bayesian/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py
new file mode 100644
index 0000000..5b40f57
--- /dev/null
+++ b/tests/test_bayesian/test_bayes.py
@@ -0,0 +1,60 @@
+import torch
+
+import eXNN.bayes as bayes_api
+import tests.utils.test_utils as utils
+
+
+def _test_bayes_prediction(mode: str):
+ """
+ Helper function to test Bayesian wrappers with different configurations.
+
+ Args:
+ mode (str): Bayesian wrapper mode ("basic", "beta", "gauss").
+
+ Verifies:
+ - Output is a dictionary with keys "mean" and "std".
+ - Shapes of "mean" and "std" match the expected shape.
+ """
+ params = {
+ "basic": dict(mode="basic", p=0.5),
+ "beta": dict(mode="beta", a=0.9, b=0.2),
+ "gauss": dict(sigma=1e-2),
+ }
+
+ n, dim, data = utils.create_testing_data()
+ num_classes = 17
+ model = utils.create_testing_model(num_classes=num_classes)
+ n_iter = 7
+
+ if mode != 'gauss':
+ res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
+ else:
+ res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
+
+ utils.compare_values(dict, type(res), "Wrong result type")
+ utils.compare_values(2, len(res), "Wrong dictionary length")
+ utils.compare_values({"mean", "std"}, set(res.keys()), "Wrong dictionary keys")
+ utils.compare_values(torch.Size([n, num_classes]), res["mean"].shape, "Wrong mean shape")
+ utils.compare_values(torch.Size([n, num_classes]), res["std"].shape, "Wrong mean std")
+
+
+def test_basic_bayes_wrapper():
+ """
+ Test the basic Bayesian wrapper.
+ """
+ _test_bayes_prediction("basic")
+
+
+def test_beta_bayes_wrapper():
+ """
+ Test the beta Bayesian wrapper.
+ """
+ _test_bayes_prediction("beta")
+
+
+def test_gauss_bayes_wrapper():
+ """
+ Test the Gaussian Bayesian wrapper.
+ """
+ _test_bayes_prediction("gauss")
+
diff --git a/tests/test_topology/__init__.py b/tests/test_topology/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py
new file mode 100644
index 0000000..189e194
--- /dev/null
+++ b/tests/test_topology/test_barcode_general.py
@@ -0,0 +1,57 @@
+import matplotlib
+
+import eXNN.topology as topology_api
+import tests.utils.test_utils as utils
+
+
+def test_data_barcode():
+ """
+ Test generating a barcode from data.
+
+ Verifies:
+ - Output is a dictionary.
+ """
+ n, dim, data = utils.create_testing_data()
+ res = topology_api.get_data_barcode(data, "standard", "3")
+ utils.compare_values(dict, type(res), "Wrong result type")
+
+
+def test_nn_barcodes():
+ """
+ Test generating barcodes for a neural network.
+
+ Verifies:
+ - Output is a dictionary.
+ - Dictionary keys match the specified layers.
+ - Each dictionary value is a dictionary.
+ """
+ n, dim, data = utils.create_testing_data()
+ model = utils.create_testing_model()
+ layers = ["second_layer", "third_layer"]
+
+ res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
+
+ utils.compare_values(dict, type(res), "Wrong result type")
+ utils.compare_values(2, len(res), "Wrong dictionary length")
+ utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys")
+
+ for layer, barcode in res.items():
+ utils.compare_values(
+ dict,
+ type(barcode),
+ f"Wrong result type for key {layer}",
+ )
+
+
+def test_barcode_plot():
+ """
+ Test generating a barcode plot.
+
+ Verifies:
+ - Output is a matplotlib.figure.Figure.
+ """
+ n, dim, data = utils.create_testing_data()
+ barcode = topology_api.get_data_barcode(data, "standard", "3")
+ plot = topology_api.plot_barcode(barcode)
+ utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type")
+
diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py
new file mode 100644
index 0000000..9d8eaa2
--- /dev/null
+++ b/tests/test_topology/test_barcode_metrics.py
@@ -0,0 +1,50 @@
+import eXNN.topology as topology_api
+import tests.utils.test_utils as utils
+
+
+def test_barcode_evaluate_all_metrics():
+ """
+ Test evaluating all metrics for a barcode.
+
+ Verifies:
+ - Output is a dictionary.
+ - Dictionary keys match the expected metric names.
+ - Each metric value is a float.
+ """
+ n, dim, data = utils.create_testing_data()
+ barcode = topology_api.get_data_barcode(data, "standard", "3")
+ result = topology_api.evaluate_barcode(barcode)
+ utils.compare_values(dict, type(result), "Wrong result type")
+ all_metric_names = [
+ "h",
+ "max_length",
+ "mean_birth",
+ "mean_death",
+ "mean_length",
+ "median_length",
+ "normh",
+ "ratio_2_1",
+ "ratio_3_1",
+ "snr",
+ "stdev_birth",
+ "stdev_death",
+ "stdev_length",
+ "sum_length",
+ ]
+ utils.compare_values(all_metric_names, sorted(result.keys()))
+ for name, value in result.items():
+ utils.compare_values(float, type(value), f"Wrong result type for metric {name}")
+
+
+def test_barcode_evaluate_one_metric():
+ """
+ Test evaluating a single metric for a barcode.
+
+ Verifies:
+ - Output is a float.
+ """
+ n, dim, data = utils.create_testing_data()
+ barcode = topology_api.get_data_barcode(data, "standard", "3")
+ result = topology_api.evaluate_barcode(barcode, metric_name="mean_length")
+ utils.compare_values(float, type(result), "Wrong result type")
+
diff --git a/tests/test_utils.py b/tests/test_utils.py
deleted file mode 100644
index 08f08e8..0000000
--- a/tests/test_utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-
-
-def _form_message_header(message_header=None):
- return "Value mismatch" if message_header is None else message_header
-
-
-def compare_values(expected, got, message_header=None):
- assert (
- expected == got
- ), f"{_form_message_header(message_header)}: expected {expected}, got {got}"
-
-
-def create_testing_data():
- N = 20
- dim = 256
- data = torch.randn((N, dim))
- return N, dim, data
-
-
-def create_testing_model(num_classes=10):
- return nn.Sequential(
- OrderedDict(
- [
- ("first_layer", nn.Linear(256, 128)),
- ("second_layer", nn.Linear(128, 64)),
- ("third_layer", nn.Linear(64, num_classes)),
- ],
- ),
- )
-
-
-class ExtractTensor(nn.Module):
- def forward(self, x):
- tensor, _ = x
- x = x.to(torch.float32)
- return tensor[:, :]
-
-
-def create_testing_model_lstm(num_classes=10):
- return nn.Sequential(
- OrderedDict(
- [
- ('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
- ('extract', ExtractTensor()),
- ('second_layer', nn.Linear(128, 64)),
- ('third_layer', nn.Linear(64, num_classes)),
- ],
- ),
- )
diff --git a/tests/test_visualization/__init__.py b/tests/test_visualization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_visualization/test_check_random_input.py b/tests/test_visualization/test_check_random_input.py
new file mode 100644
index 0000000..5bd6360
--- /dev/null
+++ b/tests/test_visualization/test_check_random_input.py
@@ -0,0 +1,18 @@
+import torch
+
+import eXNN.visualization as viz_api
+import tests.utils.test_utils as utils
+
+
+def test_check_random_input():
+ """
+ Test generating random input tensors with a specific shape.
+
+ Verifies:
+ - Output type is a torch.Tensor.
+ - Output shape matches the specified shape.
+ """
+ shape = [5, 17, 81, 37]
+ data = viz_api.get_random_input(shape)
+ utils.compare_values(torch.Tensor, type(data), "Wrong result type")
+ utils.compare_values(torch.Size(shape), data.shape, "Wrong result shape")
diff --git a/tests/test_visualization/test_reduce_dim.py b/tests/test_visualization/test_reduce_dim.py
new file mode 100644
index 0000000..a5b1723
--- /dev/null
+++ b/tests/test_visualization/test_reduce_dim.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+import eXNN.visualization as viz_api
+import tests.utils.test_utils as utils
+
+
+def _check_reduce_dim(mode):
+ """
+ Helper function to test dimensionality reduction with a given mode.
+
+ Args:
+ mode (str): Dimensionality reduction mode (e.g., "umap", "pca").
+
+ Verifies:
+ - Output type is a numpy.ndarray.
+ - Output shape is (n_samples, 2).
+ """
+ n, dim, data = utils.create_testing_data()
+ reduced_data = viz_api.reduce_dim(data, mode)
+ utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type")
+ utils.compare_values((n, 2), reduced_data.shape, "Wrong result shape")
+
+
+def test_reduce_dim_umap():
+ """
+ Test dimensionality reduction using UMAP.
+
+ Uses:
+ _check_reduce_dim with mode "umap".
+ """
+ _check_reduce_dim("umap")
+
+
+def test_reduce_dim_pca():
+ """
+ Test dimensionality reduction using PCA.
+
+ Uses:
+ _check_reduce_dim with mode "pca".
+ """
+ _check_reduce_dim("pca")
+
diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py
new file mode 100644
index 0000000..2471c0c
--- /dev/null
+++ b/tests/test_visualization/test_visualization.py
@@ -0,0 +1,119 @@
+import matplotlib
+import plotly
+import torch
+
+import eXNN.visualization as viz_api
+import tests.utils.test_utils as utils
+
+
+def test_visualization():
+ """
+ Test visualization of layer manifolds using UMAP.
+
+ Verifies:
+ - Output is a dictionary.
+ - Dictionary has the correct keys corresponding to the input and specified layers.
+ - Each dictionary value is a matplotlib.figure.Figure.
+ """
+
+ n, dim, data = utils.create_testing_data()
+ model = utils.create_testing_model()
+ layers = ["second_layer", "third_layer"]
+ res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
+
+ utils.compare_values(dict, type(res), "Wrong result type")
+ utils.compare_values(3, len(res), "Wrong dictionary length")
+ utils.compare_values(
+ set(["input"] + layers),
+ set(res.keys()),
+ "Wrong dictionary keys",
+ )
+
+ for key, plot in res.items():
+ utils.compare_values(
+ matplotlib.figure.Figure,
+ type(plot),
+ f"Wrong value type for key {key}",
+ )
+
+
+def test_embed_visualization():
+ """
+ Test visualization of recurrent layer manifolds with embeddings.
+
+ Verifies:
+ - Output is a dictionary.
+ - Dictionary keys match the specified layers.
+ - Each dictionary value is a plotly.graph_objs.Figure.
+ """
+ data = torch.randn((20, 1, 256))
+ labels = torch.randn(20)
+ model = utils.create_testing_model_lstm()
+ layers = ["second_layer", "third_layer"]
+
+ res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels)
+
+ utils.compare_values(dict, type(res), "Wrong result type")
+ utils.compare_values(2, len(res), "Wrong dictionary length")
+ utils.compare_values(
+ set(layers),
+ set(res.keys()),
+ "Wrong dictionary keys",
+ )
+ for key, plot in res.items():
+ utils.compare_values(
+ plotly.graph_objs.Figure,
+ type(plot),
+ f"Wrong value type for key {key}",
+ )
+
+
+def test_all_visualizations():
+ """
+ Test all visualizations (layer manifolds and recurrent layer manifolds).
+
+ Verifies:
+ - Visualization of layer manifolds works with UMAP for regular layers.
+ - Visualization of recurrent layer manifolds works with embeddings.
+ - Correct output types for both types of plots (matplotlib and Plotly).
+ """
+ # Test for layer manifolds visualization using UMAP
+ n, dim, data = utils.create_testing_data()
+ model = utils.create_testing_model()
+ layers = ["second_layer", "third_layer"]
+ res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
+
+ utils.compare_values(dict, type(res), "Wrong result type for layer manifolds")
+ utils.compare_values(3, len(res), "Wrong dictionary length for layer manifolds")
+ utils.compare_values(
+ set(["input"] + layers),
+ set(res.keys()),
+ "Wrong dictionary keys for layer manifolds",
+ )
+
+ for key, plot in res.items():
+ utils.compare_values(
+ matplotlib.figure.Figure,
+ type(plot),
+ f"Wrong value type for key {key} in layer manifolds",
+ )
+
+ # Test for recurrent layer manifolds visualization using embeddings
+ data = torch.randn((20, 1, 256))
+ labels = torch.randn(20)
+ model = utils.create_testing_model_lstm()
+ res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels)
+
+ utils.compare_values(dict, type(res), "Wrong result type for recurrent layer manifolds")
+ utils.compare_values(2, len(res), "Wrong dictionary length for recurrent layer manifolds")
+ utils.compare_values(
+ set(layers),
+ set(res.keys()),
+ "Wrong dictionary keys for recurrent layer manifolds",
+ )
+ for key, plot in res.items():
+ utils.compare_values(
+ plotly.graph_objs.Figure,
+ type(plot),
+ f"Wrong value type for key {key} in recurrent layer manifolds",
+ )
diff --git a/tests/tests.py b/tests/tests.py
deleted file mode 100644
index 43edb93..0000000
--- a/tests/tests.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import matplotlib
-import numpy as np
-import plotly
-import torch
-
-import eXNN.bayes as bayes_api
-import eXNN.topology as topology_api
-import eXNN.visualization as viz_api
-import tests.test_utils as utils
-
-
-def test_check_random_input():
- shape = [5, 17, 81, 37]
- data = viz_api.get_random_input(shape)
- utils.compare_values(type(data), torch.Tensor, "Wrong result type")
- utils.compare_values(torch.Size(shape), data.shape, "Wrong result shape")
-
-
-def _check_reduce_dim(mode):
- N, dim, data = utils.create_testing_data()
- reduced_data = viz_api.reduce_dim(data, mode)
- utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type")
- utils.compare_values((N, 2), reduced_data.shape, "Wrong result shape")
-
-
-def test_reduce_dim_umap():
- _check_reduce_dim("umap")
-
-
-def test_reduce_dim_pca():
- _check_reduce_dim("pca")
-
-
-def test_visualization():
- N, dim, data = utils.create_testing_data()
- model = utils.create_testing_model()
- layers = ["second_layer", "third_layer"]
- res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
-
- utils.compare_values(dict, type(res), "Wrong result type")
- utils.compare_values(3, len(res), "Wrong dictionary length")
- utils.compare_values(
- set(["input"] + layers),
- set(res.keys()),
- "Wrong dictionary keys",
- )
- for key, plot in res.items():
- utils.compare_values(
- matplotlib.figure.Figure,
- type(plot),
- f"Wrong value type for key {key}",
- )
-
-
-def test_embed_visualization():
- data = torch.randn((20, 1, 256))
- labels = torch.randn((20))
- model = utils.create_testing_model_lstm()
- layers = ["second_layer", "third_layer"]
- res = viz_api.visualize_recurrent_layer_manifolds(model, "umap",
- data, layers=layers, labels=labels)
- utils.compare_values(dict, type(res), "Wrong result type")
- utils.compare_values(2, len(res), "Wrong dictionary length")
- utils.compare_values(
- set(layers),
- set(res.keys()),
- "Wrong dictionary keys",
- )
- for key, plot in res.items():
- utils.compare_values(
- plotly.graph_objs.Figure,
- type(plot),
- f"Wrong value type for key {key}",
- )
-
-
-def _test_bayes_prediction(mode: str):
- params = {
- "basic": dict(mode="basic", p=0.5),
- "beta": dict(mode="beta", a=0.9, b=0.2),
- "gauss": dict(sigma=1e-2),
- }
-
- N, dim, data = utils.create_testing_data()
- num_classes = 17
- model = utils.create_testing_model(num_classes=num_classes)
- n_iter = 7
- if mode != 'gauss':
- res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
- else:
- res = bayes_api.GaussianBayesianWrapper(model, **(params[mode])).predict(data,
- n_iter=n_iter)
-
- utils.compare_values(dict, type(res), "Wrong result type")
- utils.compare_values(2, len(res), "Wrong dictionary length")
- utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys")
- utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape")
- utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std")
-
-
-def test_basic_bayes_wrapper():
- _test_bayes_prediction("basic")
-
-
-def test_beta_bayes_wrapper():
- _test_bayes_prediction("beta")
-
-
-def test_gauss_bayes_wrapper():
- _test_bayes_prediction("gauss")
-
-
-def test_data_barcode():
- N, dim, data = utils.create_testing_data()
- res = topology_api.get_data_barcode(data, "standard", "3")
- utils.compare_values(dict, type(res), "Wrong result type")
-
-
-def test_nn_barcodes():
- N, dim, data = utils.create_testing_data()
- model = utils.create_testing_model()
- layers = ["second_layer", "third_layer"]
- res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
- utils.compare_values(dict, type(res), "Wrong result type")
- utils.compare_values(2, len(res), "Wrong dictionary length")
- utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys")
- for layer, barcode in res.items():
- utils.compare_values(
- dict,
- type(barcode),
- f"Wrong result type for key {layer}",
- )
-
-
-def test_barcode_plot():
- N, dim, data = utils.create_testing_data()
- barcode = topology_api.get_data_barcode(data, "standard", "3")
- plot = topology_api.plot_barcode(barcode)
- utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type")
-
-
-def test_barcode_evaluate_all_metrics():
- N, dim, data = utils.create_testing_data()
- barcode = topology_api.get_data_barcode(data, "standard", "3")
- result = topology_api.evaluate_barcode(barcode)
- utils.compare_values(dict, type(result), "Wrong result type")
- all_metric_names = [
- "h",
- "max_length",
- "mean_birth",
- "mean_death",
- "mean_length",
- "median_length",
- "normh",
- "ratio_2_1",
- "ratio_3_1",
- "snr",
- "stdev_birth",
- "stdev_death",
- "stdev_length",
- "sum_length",
- ]
- utils.compare_values(all_metric_names, sorted(result.keys()))
- for name, value in result.items():
- utils.compare_values(float, type(value), f"Wrong result type for metric {name}")
-
-
-def test_barcode_evaluate_one_metric():
- N, dim, data = utils.create_testing_data()
- barcode = topology_api.get_data_barcode(data, "standard", "3")
- result = topology_api.evaluate_barcode(barcode, metric_name="mean_length")
- utils.compare_values(float, type(result), "Wrong result type")
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
new file mode 100644
index 0000000..47c85f7
--- /dev/null
+++ b/tests/utils/test_utils.py
@@ -0,0 +1,124 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+
+def _form_message_header(message_header=None):
+ """
+ Forms the header for assertion messages.
+
+ Args:
+ message_header (str, optional): Custom message header. Defaults to None.
+
+ Returns:
+ str: The custom header if provided, otherwise "Value mismatch".
+ """
+ return "Value mismatch" if message_header is None else message_header
+
+
+def compare_values(expected, got, message_header=None):
+ """
+ Compares two values and raises an assertion error if they do not match.
+
+ Args:
+ expected: The expected value.
+ got: The value to compare against the expected value.
+ message_header (str, optional): Custom header for the assertion message. Defaults to None.
+
+ Raises:
+ AssertionError: If the values do not match, an error with the message header is raised.
+ """
+ assert (
+ expected == got
+ ), f"{_form_message_header(message_header)}: expected {expected}, got {got}"
+
+
+def create_testing_data():
+ """
+ Creates synthetic testing data.
+
+ Returns:
+ tuple: A tuple containing:
+            - n (int): Number of samples.
+            - dim (int): Dimensionality of each sample.
+            - data (torch.Tensor): Randomly generated data tensor of shape (n, dim).
+ """
+ n = 20
+ dim = 256
+ data = torch.randn((n, dim))
+ return n, dim, data
+
+
+def create_testing_model(num_classes=10):
+ """
+ Creates a simple feedforward neural network for testing.
+
+ Args:
+ num_classes (int, optional): Number of output classes. Defaults to 10.
+
+ Returns:
+ nn.Sequential: A sequential model with three layers:
+ - Linear layer (input_dim=256, output_dim=128).
+ - Linear layer (input_dim=128, output_dim=64).
+ - Linear layer (input_dim=64, output_dim=num_classes).
+ """
+ return nn.Sequential(
+ OrderedDict(
+ [
+ ("first_layer", nn.Linear(256, 128)),
+ ("second_layer", nn.Linear(128, 64)),
+ ("third_layer", nn.Linear(64, num_classes)),
+ ],
+ ),
+ )
+
+
+class ExtractTensor(nn.Module):
+ """
+    A custom PyTorch module that unwraps recurrent-layer outputs.
+
+    nn.LSTM returns a tuple (output, (h_n, c_n)); this module keeps only the
+    output sequence, converts it to float32, and returns it.
+ """
+
+ @staticmethod
+ def forward(x):
+ """
+ Forward pass for tensor extraction and processing.
+
+ Args:
+ x (tuple): Input tuple where the first element is the tensor to be processed.
+
+ Returns:
+ torch.Tensor: Processed tensor (converted to float32).
+ """
+ tensor, _ = x
+ tensor = tensor.to(torch.float32)
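+        # tensor[:, :] returns the full (batch, seq_len, hidden) sequence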
+ return tensor[:, :]
+
+
+def create_testing_model_lstm(num_classes=10):
+ """
+ Creates a recurrent neural network with LSTM layers for testing.
+
+ Args:
+ num_classes (int, optional): Number of output classes. Defaults to 10.
+
+ Returns:
+ nn.Sequential: A sequential model with the following layers:
+ - LSTM layer (input_dim=256, hidden_dim=128, num_layers=1).
+ - ExtractTensor layer to process LSTM output.
+ - Linear layer (input_dim=128, output_dim=64).
+ - Linear layer (input_dim=64, output_dim=num_classes).
+ """
+ return nn.Sequential(
+ OrderedDict(
+ [
+                ("first_layer", nn.LSTM(256, 128, 1, batch_first=True)),
+                ("extract", ExtractTensor()),
+                ("second_layer", nn.Linear(128, 64)),
+                ("third_layer", nn.Linear(64, num_classes)),
+ ],
+ ),
+ )
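
A quick shape walkthrough of the LSTM helper above may be useful; this sketch assumes only the layer definitions in create_testing_model_lstm.

# Sketch only: trace shapes through the LSTM testing model.
import torch

model = create_testing_model_lstm(num_classes=10)
data = torch.randn((20, 1, 256))  # (batch, seq_len, input_dim)
out = model(data)
# LSTM emits (20, 1, 128); ExtractTensor drops the (h_n, c_n) state tuple;
# the Linear layers act on the last dimension: -> (20, 1, 64) -> (20, 1, 10).
print(out.shape)  # torch.Size([20, 1, 10])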
From 0f5ba794045554cfd2f578426acdfb8a37df05d6 Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 03:34:23 +0300
Subject: [PATCH 07/16] feat: add new tests and remove extra one for
visualization
---
tests/test_bayesian/test_bayes.py | 9 ++++
tests/test_topology/test_barcode_general.py | 26 ++++++++++
tests/test_topology/test_barcode_metrics.py | 37 ++++++++++++++
tests/test_visualization/test_reduce_dim.py | 13 +++++
.../test_visualization/test_visualization.py | 51 -------------------
5 files changed, 85 insertions(+), 51 deletions(-)
diff --git a/tests/test_bayesian/test_bayes.py b/tests/test_bayesian/test_bayes.py
index 5b40f57..290fb7c 100644
--- a/tests/test_bayesian/test_bayes.py
+++ b/tests/test_bayesian/test_bayes.py
@@ -58,3 +58,12 @@ def test_gauss_bayes_wrapper():
"""
_test_bayes_prediction("gauss")
+
+def test_all_bayes_wrappers():
+ """
+ Test all Bayesian wrappers (basic, beta, and gauss) with a single general test.
+ """
+ modes = ["basic", "beta", "gauss"]
+
+ for mode in modes:
+ _test_bayes_prediction(mode)
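
For reference, a sketch of the wrapper API that _test_bayes_prediction drives; constructor arguments mirror the test fixtures (p=0.5, sigma=1e-2) rather than recommended settings, and the helpers are assumed to live at tests.utils.test_utils.

# Sketch only: calling the Bayesian wrappers directly.
import eXNN.bayes as bayes_api
from tests.utils.test_utils import create_testing_data, create_testing_model

_, _, data = create_testing_data()
model = create_testing_model(num_classes=17)

# Dropout-based bayesianization; predict returns {"mean": ..., "std": ...},
# each of shape [N, num_classes].
res = bayes_api.DropoutBayesianWrapper(model, mode="basic", p=0.5).predict(data, n_iter=7)

# Gaussian bayesianization takes a noise scale instead of a dropout rate.
res = bayes_api.GaussianBayesianWrapper(model, sigma=1e-2).predict(data, n_iter=7)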
diff --git a/tests/test_topology/test_barcode_general.py b/tests/test_topology/test_barcode_general.py
index 189e194..174ce28 100644
--- a/tests/test_topology/test_barcode_general.py
+++ b/tests/test_topology/test_barcode_general.py
@@ -55,3 +55,29 @@ def test_barcode_plot():
plot = topology_api.plot_barcode(barcode)
utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type")
+
+def test_all_barcodes():
+ """
+ Test all barcode-related functions (data barcode, NN barcodes, and barcode plot) together.
+
+ Verifies:
+ - Data barcode is generated correctly.
+ - Neural network barcodes are generated correctly for specified layers.
+ - Barcode plot is generated correctly.
+ """
+ # Test data barcode
+ n, dim, data = utils.create_testing_data()
+ res = topology_api.get_data_barcode(data, "standard", "3")
+ utils.compare_values(dict, type(res), "Wrong result type for data barcode")
+
+ # Test NN barcodes
+ model = utils.create_testing_model()
+ layers = ["second_layer", "third_layer"]
+ nn_barcodes = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
+ utils.compare_values(dict, type(nn_barcodes), "Wrong result type for NN barcodes")
+ utils.compare_values(2, len(nn_barcodes), "Wrong dictionary length for NN barcodes")
+    utils.compare_values(
+        set(layers),
+        set(nn_barcodes.keys()),
+        "Wrong dictionary keys for NN barcodes",
+    )
+
+ # Test barcode plot
+ barcode_plot = topology_api.plot_barcode(res)
+    utils.compare_values(
+        matplotlib.figure.Figure,
+        type(barcode_plot),
+        "Wrong result type for barcode plot",
+    )
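
As a standalone sketch of the same pipeline: the "standard" backend name and the "3" dimension string mirror the test arguments and are used here as-is, not as documented defaults; the helper import path is an assumption.

# Sketch only: data barcode -> NN barcodes -> barcode plot.
import eXNN.topology as topology_api
from tests.utils.test_utils import create_testing_data, create_testing_model

_, _, data = create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")

model = create_testing_model()
layers = ["second_layer", "third_layer"]
nn_barcodes = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")

fig = topology_api.plot_barcode(barcode)  # matplotlib Figure
fig.savefig("barcode.png")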
diff --git a/tests/test_topology/test_barcode_metrics.py b/tests/test_topology/test_barcode_metrics.py
index 9d8eaa2..53a3a29 100644
--- a/tests/test_topology/test_barcode_metrics.py
+++ b/tests/test_topology/test_barcode_metrics.py
@@ -48,3 +48,40 @@ def test_barcode_evaluate_one_metric():
result = topology_api.evaluate_barcode(barcode, metric_name="mean_length")
utils.compare_values(float, type(result), "Wrong result type")
+
+def test_barcode_evaluate_all_metrics_and_individual():
+ """
+    Test evaluating all metrics for a barcode, as well as each metric individually.
+
+ Verifies:
+ - Output for evaluating all metrics is a dictionary.
+ - Dictionary keys match the expected metric names.
+ - Each metric value is a float.
+ - Individual metrics can be correctly evaluated.
+ """
+ # Test for all metrics
+ n, dim, data = utils.create_testing_data()
+ barcode = topology_api.get_data_barcode(data, "standard", "3")
+ result = topology_api.evaluate_barcode(barcode)
+
+ # Check that the result is a dictionary
+ utils.compare_values(dict, type(result), "Wrong result type")
+
+ # List of expected metric names
+ all_metric_names = [
+ "h", "max_length", "mean_birth", "mean_death", "mean_length",
+ "median_length", "normh", "ratio_2_1", "ratio_3_1", "snr",
+ "stdev_birth", "stdev_death", "stdev_length", "sum_length"
+ ]
+
+ # Check that the dictionary keys match the expected metric names
+ utils.compare_values(all_metric_names, sorted(result.keys()), "Wrong dictionary keys")
+
+ # Ensure all metric values are floats
+ for name, value in result.items():
+ utils.compare_values(float, type(value), f"Wrong result type for metric {name}")
+
+ # Test for evaluating individual metrics
+ for metric_name in all_metric_names:
+ individual_result = topology_api.evaluate_barcode(barcode, metric_name=metric_name)
+        utils.compare_values(
+            float,
+            type(individual_result),
+            f"Wrong result type for individual metric {metric_name}",
+        )
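
To make the metric values concrete, a hedged sketch on a hand-built barcode; the {"H0": [(birth, death), ...]} layout is an assumption inferred from the interval-length metrics, not a documented input format.

# Sketch only: evaluate a single metric on a hand-built barcode (format assumed).
import eXNN.topology as topology_api

barcode = {"H0": [(0.0, 0.4), (0.0, 0.9)]}
# mean_length averages (death - birth) over H0: (0.4 + 0.9) / 2 = 0.65
print(topology_api.evaluate_barcode(barcode, metric_name="mean_length"))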
diff --git a/tests/test_visualization/test_reduce_dim.py b/tests/test_visualization/test_reduce_dim.py
index a5b1723..8ed32b9 100644
--- a/tests/test_visualization/test_reduce_dim.py
+++ b/tests/test_visualization/test_reduce_dim.py
@@ -40,3 +40,16 @@ def test_reduce_dim_pca():
"""
_check_reduce_dim("pca")
+
+def test_all_reduce_dim_methods():
+ """
+    Test dimensionality reduction using all available methods (UMAP and PCA).
+
+ Verifies:
+ - Dimensionality reduction works for each method.
+ - Output is of type numpy.ndarray and shape (n_samples, 2).
+ """
+ modes = ["umap", "pca"]
+
+ for mode in modes:
+ _check_reduce_dim(mode)
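
A minimal usage sketch of the helper under test; the expected type and shape come from the assertions in _check_reduce_dim, and the helper import path is an assumption.

# Sketch only: run every supported reduction mode on the shared test data.
import eXNN.visualization as viz_api
from tests.utils.test_utils import create_testing_data

n, dim, data = create_testing_data()
for mode in ["umap", "pca"]:
    reduced = viz_api.reduce_dim(data, mode)
    print(mode, type(reduced), reduced.shape)  # numpy.ndarray, (20, 2)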
diff --git a/tests/test_visualization/test_visualization.py b/tests/test_visualization/test_visualization.py
index 2471c0c..31a6206 100644
--- a/tests/test_visualization/test_visualization.py
+++ b/tests/test_visualization/test_visualization.py
@@ -66,54 +66,3 @@ def test_embed_visualization():
type(plot),
f"Wrong value type for key {key}",
)
-
-
-def test_all_visualizations():
- """
- Test all visualizations (layer manifolds and recurrent layer manifolds).
-
- Verifies:
- - Visualization of layer manifolds works with UMAP for regular layers.
- - Visualization of recurrent layer manifolds works with embeddings.
- - Correct output types for both types of plots (matplotlib and Plotly).
- """
- # Test for layer manifolds visualization using UMAP
- n, dim, data = utils.create_testing_data()
- model = utils.create_testing_model()
- layers = ["second_layer", "third_layer"]
- res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
-
- utils.compare_values(dict, type(res), "Wrong result type for layer manifolds")
- utils.compare_values(3, len(res), "Wrong dictionary length for layer manifolds")
- utils.compare_values(
- set(["input"] + layers),
- set(res.keys()),
- "Wrong dictionary keys for layer manifolds",
- )
-
- for key, plot in res.items():
- utils.compare_values(
- matplotlib.figure.Figure,
- type(plot),
- f"Wrong value type for key {key} in layer manifolds",
- )
-
- # Test for recurrent layer manifolds visualization using embeddings
- data = torch.randn((20, 1, 256))
- labels = torch.randn(20)
- model = utils.create_testing_model_lstm()
- res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels)
-
- utils.compare_values(dict, type(res), "Wrong result type for recurrent layer manifolds")
- utils.compare_values(2, len(res), "Wrong dictionary length for recurrent layer manifolds")
- utils.compare_values(
- set(layers),
- set(res.keys()),
- "Wrong dictionary keys for recurrent layer manifolds",
- )
- for key, plot in res.items():
- utils.compare_values(
- plotly.graph_objs.Figure,
- type(plot),
- f"Wrong value type for key {key} in recurrent layer manifolds",
- )
From 842803ba82cec2abbc23733a8dc53318c36f9cfb Mon Sep 17 00:00:00 2001
From: asmoorr <90kidex90@gmail.com>
Date: Sun, 22 Dec 2024 03:41:30 +0300
Subject: [PATCH 08/16] feat: change some filenames
---
.idea/.gitignore | 3 +++
.idea/eXplain-NNs.iml | 15 +++++++++++++++
.idea/inspectionProfiles/Project_Default.xml | 6 ++++++
.../inspectionProfiles/profiles_settings.xml | 6 ++++++
.idea/misc.xml | 7 +++++++
.idea/modules.xml | 8 ++++++++
.idea/vcs.xml | 6 ++++++
README_eng.md | 2 +-
README.md => README_ru.md | 0
...ors.svg => strong_ai_in_industry_logo.svg} | 0
eXNN/bayes/api.py | 19 +++++++++----------
11 files changed, 61 insertions(+), 11 deletions(-)
create mode 100644 .idea/.gitignore
create mode 100644 .idea/eXplain-NNs.iml
create mode 100644 .idea/inspectionProfiles/Project_Default.xml
create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 .idea/vcs.xml
rename README.md => README_ru.md (100%)
rename docs/{AIM-Strong_Sign_Norm-01_Colors.svg => strong_ai_in_industry_logo.svg} (100%)
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/eXplain-NNs.iml b/.idea/eXplain-NNs.iml
new file mode 100644
index 0000000..28bc459
--- /dev/null
+++ b/.idea/eXplain-NNs.iml
@@ -0,0 +1,15 @@
+
+
-
-[![SAI](https://github.com/ITMO-NSS-team/open-source-ops/blob/master/badges/SAI_badge_flat.svg)](https://sai.itmo.ru/)
-[![ITMO](https://github.com/ITMO-NSS-team/open-source-ops/blob/master/badges/ITMO_badge_flat_rus.svg)](https://en.itmo.ru/en/)
-
-[![Documentation](https://github.com/aimclub/eXplain-NNs/actions/workflows/pages/pages-build-deployment/badge.svg)](https://med-ai-lab.github.io/eXplain-NNs-documentation/)
-[![license](https://img.shields.io/github/license/aimclub/eXplain-NNs)](https://github.com/aimclub/eXplain-NNs/blob/main/LICENSE)
-[![Eng](https://img.shields.io/badge/lang-en-red.svg)](/README)
-[![Mirror](https://img.shields.io/badge/mirror-GitLab-orange)](https://gitlab.actcognitive.org/itmo-sai-code/eXplain-NNs)
-
-# eXplain-NNs
-
-This repository contains eXplain-NNs, an open-source library with explainable AI (XAI) methods
-for the analysis of neural networks. The library provides several XAI methods for latent space
-analysis and uncertainty estimation.
-
-## Project description
-
-### Methods
-
-XAI methods implemented in the library:
-
-1. latent space visualization
-2. homological analysis of latent spaces
-3. uncertainty estimation via bayesianization
-
-Thus, compared to other explainable AI libraries, eXplain-NNs:
-
-* provides homology analysis of latent spaces
-* introduces a novel method of uncertainty estimation via bayesianization
-
-Details of the [method implementations](/docs/methods.md).
-
-### Data requirements
-
-* The library supports only models that are:
-    * fully connected or convolutional
-    * designed for classification tasks
-
-## Installation
-
-Requirements: Python 3.8
-
-1. [optional] create a Python environment, e.g.
-    ```
-    $ conda create -n eXNN python=3.8
-    $ conda activate eXNN
-    ```
-2. install dependencies from [requirements.txt](/requirements.txt)
-    ```
-    $ pip install -r requirements.txt
-    ```
-3. install the library as a package
-    ```
-    $ python -m pip install git+ssh://git@github.com/Med-AI-Lab/eXplain-NNs
-    ```
-
-A video of the installation process is available [here](https://drive.google.com/file/d/1Sv8UiRwWfMLJ0kOSYHB_PgILHzNcqfs0/view?usp=sharing).
-
-## Development
-
-Requirements: Python 3.8
-
-1. [optional] create a Python environment, e.g.
-    ```
-    $ conda create -n eXNN python=3.8
-    $ conda activate eXNN
-    ```
-2. clone the repository and install dependencies
-    ```
-    $ git clone git@github.com:Med-AI-Lab/eXplain-NNs.git
-    $ cd eXplain-NNs
-    $ pip install -r requirements.txt
-    ```
-3. run tests
-    ```
-    $ pytest tests
-    ```
-4. fix code style to PEP8 automatically
-    ```
-    $ make format
-    ```
-5. check code style for PEP8 compliance
-    ```
-    $ make check
-    ```
-6. build the PyPI package locally
-    ```
-    $ python3 -m pip install --upgrade build
-    $ python3 -m build
-    ```
-
-## Documentation
-
-[Documentation](https://med-ai-lab.github.io/eXplain-NNs-documentation/)
-
-[API](https://med-ai-lab.github.io/eXplain-NNs-documentation/api_docs/eXNN.html)
-
-## Examples and tutorials
-
-We provide examples of different levels of complexity:
-
-* [minimal] minimalistic examples presenting our API
-* [basic] applying eXNN to simple tasks such as MNIST classification
-* [use cases] demonstrations of using eXplain-NNs to solve various problems arising in industrial tasks
-
-### Minimal
-
-This Colab contains a minimalistic demonstration of our API on dummy objects:
-
-[![minimal](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1lOiB50LppDiiRHTv184JMuQ2IvZ4I4rp?usp=sharing)
-
-### Basic
-
-These Colabs demonstrate how to work with different modules of our API on simple tasks:
-
-| Colab Link | Module |
-|------------|---------------|
-| [![bayes](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ayd0IronxUIfnbAmWQLHiILG2qtBBpF4?usp=sharing) | bayes |
-| [![topology](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1T5ENfNaCIRI61LM2ZhtU8lfmvRmlfiEo?usp=sharing) | topology |
-| [![visualization](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LJVdWTv-wcASSMX4is_E15TR7XJsT7W3?usp=sharing) | visualization |
-
-### Use cases
-
-This section presents examples of applying eXplain-NNs to different use cases arising in
-industrial tasks. Three tasks are used for demonstration purposes:
-
-* [satellite] landscape classification from satellite imagery
-* [electronics] classification of electronic components and devices
-* [ECG] ECG diagnostics
-
-| Colab Link | Task | Use Case |
-|------------|------|----------|
-| [![CNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12ZJigH-0geGTefNXnCM5dQ71d4tqlf6L?usp=sharing) | satellite | Visualization of how the data manifold changes from layer to layer |
-| [![adv](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1n50WUu2ZKZ6nrT9DuFD3q87m3yZvxkwm?usp=sharing) | satellite | Adversarial data detection |
-| [![generalize](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mG-VrP7J7OoCvIQDl7n5YWEIdyfFg_0I?usp=sharing) | electronics | Estimation of a neural network's generalization ability |
-| [![RNN_viz](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aAtqxQLcOsSJJumfsmS9HGLgHrOFHlfk?usp=sharing) | ECG | Visualization of how the data manifold changes from layer to layer |
-
-## Contributing
-
-[Instructions](/docs/contribution.md).
-
-## Organizational
-
-### Affiliation
-
-[ITMO University](https://en.itmo.ru/).
-
-### Support
-
-The research is conducted with the support of the [Research Center for Strong Artificial Intelligence in Industry](