Merge pull request #71 from Asmoorr/main
Changes during GitHub review
tanyapole authored Jan 14, 2025
2 parents 599e626 + c57bd3b commit e36fff3
Showing 18 changed files with 1,078 additions and 207 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -127,3 +127,7 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.idea
+/.idea
+.idea/
File renamed without changes
19 changes: 9 additions & 10 deletions eXNN/bayes/api.py
@@ -1,19 +1,18 @@
 from typing import Dict, Optional
 
 import torch
-import torch.optim
 
 from eXNN.bayes.wrapper import create_dropout_bayesian_wrapper
 
 
 class DropoutBayesianWrapper:
     def __init__(
-            self,
-            model: torch.nn.Module,
-            mode: str,
-            p: Optional[float] = None,
-            a: Optional[float] = None,
-            b: Optional[float] = None,
+        self,
+        model: torch.nn.Module,
+        mode: str,
+        p: Optional[float] = None,
+        a: Optional[float] = None,
+        b: Optional[float] = None,
     ):
         """Class representing bayesian equivalent of a neural network.
@@ -49,9 +48,9 @@ def predict(self, data, n_iter) -> Dict[str, torch.Tensor]:
 
 class GaussianBayesianWrapper:
     def __init__(
-            self,
-            model: torch.nn.Module,
-            sigma: float,
+        self,
+        model: torch.nn.Module,
+        sigma: float,
     ):
         """Class representing bayesian equivalent of a neural network.
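
For orientation, a minimal usage sketch of the API touched above (not part of the diff). The toy model and parameter values are hypothetical, and the mode strings are assumed to match the factory in wrapper.py ("basic", "beta", "gauss"); `predict` is typed as returning Dict[str, torch.Tensor], but its exact keys are collapsed out of this view:

import torch
import torch.nn as nn

from eXNN.bayes.api import DropoutBayesianWrapper

# Hypothetical toy model; any network built from Linear/Conv layers works
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))

# Plain dropout mode; mode="beta" with a and b is the other dropout option
wrapped = DropoutBayesianWrapper(model, mode="basic", p=0.1)

data = torch.randn(8, 16)
stats = wrapped.predict(data, n_iter=10)  # Dict[str, torch.Tensor]
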
208 changes: 172 additions & 36 deletions eXNN/bayes/wrapper.py
@@ -3,18 +3,31 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as functional
 from torch.distributions import Beta
 
 
 class ModuleBayesianWrapper(nn.Module):
+    """
+    A wrapper for neural network layers to apply Bayesian-style dropout or noise during training.
+
+    Args:
+        layer (nn.Module): The layer to wrap (e.g., nn.Linear, nn.Conv2d).
+        p (Optional[float]): Dropout probability for simple dropout.
+            Mutually exclusive with `a`, `b`, and `sigma`.
+        a (Optional[float]): Alpha parameter for Beta distribution dropout. Used with `b`.
+        b (Optional[float]): Beta parameter for Beta distribution dropout. Used with `a`.
+        sigma (Optional[float]): Standard deviation for Gaussian noise.
+            Mutually exclusive with `p`, `a`, and `b`.
+    """
+
     def __init__(
-            self,
-            layer: nn.Module,
-            p: Optional[float] = None,
-            a: Optional[float] = None,
-            b: Optional[float] = None,
-            sigma: Optional[float] = None,
+        self,
+        layer: nn.Module,
+        p: Optional[float] = None,
+        a: Optional[float] = None,
+        b: Optional[float] = None,
+        sigma: Optional[float] = None,
     ):
         super(ModuleBayesianWrapper, self).__init__()
 
@@ -36,6 +49,16 @@ def __init__(
         self.p, self.a, self.b, self.sigma = p, a, b, sigma
 
     def augment_weights(self, weights, bias):
+        """
+        Apply the specified noise or dropout to the weights and bias.
+
+        Args:
+            weights (torch.Tensor): The weights of the layer.
+            bias (torch.Tensor): The bias of the layer (can be None).
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The augmented weights and bias.
+        """
+
         # Check if dropout is chosen
         if (self.p is not None) or (self.a is not None and self.b is not None):
@@ -45,10 +68,10 @@ def augment_weights(self, weights, bias):
             else:
                 p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample()
 
-            weights = F.dropout(weights, p, training=True)
+            weights = functional.dropout(weights, p, training=True)
             if bias is not None:
                 # In layers we sometimes have the ability to set bias to None
-                bias = F.dropout(bias, p, training=True)
+                bias = functional.dropout(bias, p, training=True)
 
         else:
             # If gauss is chosen, then apply it
@@ -60,18 +83,39 @@ def augment_weights(self, weights, bias):
         return weights, bias
 
     def forward(self, x):
+        """
+        Forward pass through the layer with augmented weights.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
+
         weight, bias = self.augment_weights(self.layer.weight, self.layer.bias)
 
         if isinstance(self.layer, nn.Linear):
-            return F.linear(x, weight, bias)
+            return functional.linear(x, weight, bias)
         elif type(self.layer) in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
             return self.layer._conv_forward(x, weight, bias)
         else:
             return self.layer(x)
 
 
 def replace_modules_with_wrapper(model, wrapper_module, params):
+    """
+    Recursively replaces layers in a model with a Bayesian wrapper.
+
+    Args:
+        model (nn.Module): The model containing layers to replace.
+        wrapper_module (type): The wrapper class.
+        params (dict): Parameters for the wrapper.
+
+    Returns:
+        nn.Module: The model with wrapped layers.
+    """
+
     if type(model) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]:
         return wrapper_module(model, **params)

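As a side note, a small sketch of what the two pieces above do (illustrative only; the layer sizes and p=0.5 are made-up values):

import torch
import torch.nn as nn

from eXNN.bayes.wrapper import ModuleBayesianWrapper, replace_modules_with_wrapper

# Wrap a single layer: its weights are re-dropped on every forward pass
noisy = ModuleBayesianWrapper(nn.Linear(4, 3), p=0.5)
x = torch.randn(2, 4)
y1, y2 = noisy(x), noisy(x)  # stochastic, so y1 and y2 almost surely differ

# Or wrap every Linear/Conv layer of a whole model recursively
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
bayes_model = replace_modules_with_wrapper(model, ModuleBayesianWrapper, {"p": 0.5})
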
@@ -88,12 +132,26 @@ def replace_modules_with_wrapper(model, wrapper_module, params):
 
 
 class NetworkBayes(nn.Module):
+    """
+    Bayesian network with standard dropout.
+
+    Args:
+        model (nn.Module): The base model.
+        dropout_p (float): Dropout probability.
+    """
+
     def __init__(
-            self,
-            model: nn.Module,
-            dropout_p: float,
+        self,
+        model: nn.Module,
+        dropout_p: float,
     ):
+        """
+        Initialize the NetworkBayes with standard dropout.
+
+        Args:
+            model (nn.Module): The base model to wrap with Bayesian dropout.
+            dropout_p (float): Dropout probability for the Bayesian wrapper.
+        """
         super(NetworkBayes, self).__init__()
         self.model = copy.deepcopy(model)
         self.model = replace_modules_with_wrapper(
@@ -103,11 +161,21 @@ def __init__(
         )
 
     def mean_forward(
-            self,
-            data: torch.Tensor,
-            n_iter: int,
+        self,
+        data: torch.Tensor,
+        n_iter: int,
     ):
+        """
+        Perform forward passes to estimate the mean and standard deviation of outputs.
+
+        Args:
+            data (torch.Tensor): Input tensor.
+            n_iter (int): Number of stochastic forward passes.
+
+        Returns:
+            torch.Tensor: A tensor containing the mean (dim=0) and
+                standard deviation (dim=1) of outputs.
+        """
         results = []
         for _ in range(n_iter):
             results.append(self.model.forward(data))
@@ -125,13 +193,29 @@ def mean_forward(
 
 # calculate mean and std after applying bayesian with beta distribution
 class NetworkBayesBeta(nn.Module):
+    """
+    Bayesian network with Beta distribution dropout.
+
+    Args:
+        model (nn.Module): The base model.
+        alpha (float): Alpha parameter for the Beta distribution.
+        beta (float): Beta parameter for the Beta distribution.
+    """
+
     def __init__(
-            self,
-            model: torch.nn.Module,
-            alpha: float,
-            beta: float,
+        self,
+        model: torch.nn.Module,
+        alpha: float,
+        beta: float,
     ):
+        """
+        Initialize the NetworkBayesBeta with Beta distribution dropout.
+
+        Args:
+            model (nn.Module): The base model to wrap with Bayesian Beta dropout.
+            alpha (float): Alpha parameter of the Beta distribution.
+            beta (float): Beta parameter of the Beta distribution.
+        """
         super(NetworkBayesBeta, self).__init__()
         self.model = copy.deepcopy(model)
         self.model = replace_modules_with_wrapper(
@@ -141,11 +225,21 @@ def __init__(
         )
 
     def mean_forward(
-            self,
-            data: torch.Tensor,
-            n_iter: int,
+        self,
+        data: torch.Tensor,
+        n_iter: int,
     ):
+        """
+        Perform forward passes to estimate the mean and standard deviation of outputs.
+
+        Args:
+            data (torch.Tensor): Input tensor.
+            n_iter (int): Number of stochastic forward passes.
+
+        Returns:
+            torch.Tensor: A tensor containing the mean (dim=0) and
+                standard deviation (dim=1) of outputs.
+        """
         results = []
         for _ in range(n_iter):
             results.append(self.model.forward(data))
@@ -163,12 +257,26 @@ def mean_forward(
 
 
 class NetworkBayesGauss(nn.Module):
+    """
+    Bayesian network with Gaussian noise.
+
+    Args:
+        model (nn.Module): The base model.
+        sigma (float): Standard deviation of the Gaussian noise.
+    """
+
     def __init__(
-            self,
-            model: torch.nn.Module,
-            sigma: float,
+        self,
+        model: torch.nn.Module,
+        sigma: float,
     ):
+        """
+        Initialize the NetworkBayesGauss with Gaussian noise.
+
+        Args:
+            model (nn.Module): The base model to wrap with Bayesian Gaussian noise.
+            sigma (float): Standard deviation of the Gaussian noise to apply.
+        """
        super(NetworkBayesGauss, self).__init__()
        self.model = copy.deepcopy(model)
        self.model = replace_modules_with_wrapper(
@@ -178,11 +286,21 @@ def __init__(
         )
 
     def mean_forward(
-            self,
-            data: torch.Tensor,
-            n_iter: int,
+        self,
+        data: torch.Tensor,
+        n_iter: int,
     ):
+        """
+        Perform forward passes to estimate the mean and standard deviation of outputs.
+
+        Args:
+            data (torch.Tensor): Input tensor.
+            n_iter (int): Number of stochastic forward passes.
+
+        Returns:
+            torch.Tensor: A tensor containing the mean (dim=0) and
+                standard deviation (dim=1) of outputs.
+        """
         results = []
         for _ in range(n_iter):
             results.append(self.model.forward(data))
@@ -200,13 +318,28 @@ def mean_forward(
 
 
 def create_dropout_bayesian_wrapper(
-        model: torch.nn.Module,
-        mode: Optional[str] = "basic",
-        p: Optional[float] = None,
-        a: Optional[float] = None,
-        b: Optional[float] = None,
-        sigma: Optional[float] = None,
+    model: torch.nn.Module,
+    mode: Optional[str] = "basic",
+    p: Optional[float] = None,
+    a: Optional[float] = None,
+    b: Optional[float] = None,
+    sigma: Optional[float] = None,
 ) -> torch.nn.Module:
+    """
+    Creates a Bayesian network with the specified dropout mode.
+
+    Args:
+        model (nn.Module): The base model.
+        mode (str): The dropout mode ("basic", "beta", "gauss").
+        p (Optional[float]): Dropout probability for "basic" mode.
+        a (Optional[float]): Alpha parameter for "beta" mode.
+        b (Optional[float]): Beta parameter for "beta" mode.
+        sigma (Optional[float]): Standard deviation for "gauss" mode.
+
+    Returns:
+        nn.Module: The Bayesian network.
+    """
+
     if mode == "basic":
         net = NetworkBayes(model, p)
 
@@ -216,4 +349,7 @@ def create_dropout_bayesian_wrapper(
     elif mode == 'gauss':
         net = NetworkBayesGauss(model, sigma)
 
+    else:
+        raise ValueError("Mode should be one of ('basic', 'beta', 'gauss').")
+
     return net
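
Putting it together, a short sketch of the factory and the uncertainty estimate it enables (illustrative; the toy model and sigma value are assumptions). Note that with the new else branch, a misspelled mode now raises a clear ValueError instead of hitting an UnboundLocalError on net:

import torch
import torch.nn as nn

from eXNN.bayes.wrapper import create_dropout_bayesian_wrapper

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
bayes_net = create_dropout_bayesian_wrapper(model, mode="gauss", sigma=0.01)

data = torch.randn(8, 16)
res = bayes_net.mean_forward(data, n_iter=25)
mean, std = res[0], res[1]  # per the docstrings: mean stacked first, then std

# create_dropout_bayesian_wrapper(model, mode="typo")  # now raises ValueError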