Handling issue: https://github.com/pycroscopy/atomai/issues/89 #90

Closed · wants to merge 5 commits
48 changes: 48 additions & 0 deletions atomai/losses_metrics/vi_losses.py
@@ -136,6 +136,54 @@ def rvae_loss(recon_loss: str,
kl_div = infocapacity(kl_div, capacity, num_iter=num_iter)
return likelihood - kl_div

def rvae_loss_lvae(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
y: torch.Tensor = None,
*args: torch.Tensor,
**kwargs: Union[List[float], float]
) -> torch.Tensor:
"""
Calculates ELBO
"""
if len(args) == 2:
z_mean, z_logsd = args
else:
raise ValueError(
"Pass mean and SD values of encoded distribution as args")

phi_prior = kwargs.get("phi_prior", 0.1)
capacity = kwargs.get("capacity")
num_iter = kwargs.get("num_iter", 0)

phi_logsd = z_logsd[:, 0]
z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:]

# Reconstruction loss
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()

# KL divergence
kl_rot = kld_rot(phi_prior, phi_logsd).mean()
kl_z = kld_normal([z_mean, z_logsd]).mean()
kl_div = (kl_z + kl_rot)
if capacity is not None:
kl_div = infocapacity(kl_div, capacity, num_iter=num_iter)

# Custom loss
if y is not None:
batch_size = x.size(0)
y_true = y[:batch_size].float()
        custom_loss = torch.mean((z_mean[:, -1] - y_true) ** 2) * 10  # supervision of the last latent variable
else:
custom_loss = torch.tensor(0.0, device=x.device)

    # Total loss
    # Intuition: the likelihood should be maximized, while the KL term
    # and the custom supervision loss should be minimized
    total_loss = likelihood - kl_div - custom_loss
return total_loss
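# A minimal sketch (hypothetical shapes and values) of how this loss is
# invoked; the labels y supervise the last latent variable through the
# custom MSE term above:
#
# >>> in_dim = (64, 64)
# >>> x = torch.rand(32, *in_dim)            # input batch
# >>> x_reconstr = torch.rand(32, *in_dim)   # decoder output
# >>> z_mean = torch.randn(32, 6)            # [angle | 5 remaining latents]
# >>> z_logsd = torch.randn(32, 6)
# >>> y = torch.rand(32)                     # targets for the last latent
# >>> elbo = rvae_loss_lvae("mse", in_dim, x, x_reconstr, y,
# ...                       z_mean, z_logsd, phi_prior=0.1)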


def joint_vae_loss(recon_loss: str,
in_dim: Tuple[int],
221 changes: 221 additions & 0 deletions atomai/models/dgm/L_rvae.py
@@ -0,0 +1,221 @@
"""
rvae.py
=======

Module for analysis of system "building blocks" with
rotationally-invariant variational autoencoders

Created by Maxim Ziatdinov (email: [email protected])
"""

from copy import deepcopy as dc
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ...losses_metrics.vi_losses import rvae_loss_lvae
from ...utils import set_train_rng, to_onehot, transform_coordinates
from .vae_for_lvae import BaseVAE


class LrVAE(BaseVAE):
"""
Implements rotationally and translationally invariant
Variational Autoencoder (VAE) based on the idea of "spatial decoder"
by Bepler et al. in arXiv:1909.11663. In addition, this class allows
    implementing the class-conditioned VAE and skip-VAE (arXiv:1807.04863)
with rotational and translational variance.

Args:
in_dim:
            Input dimensions for image data passed as (height, width)
for grayscale data or (height, width, channels)
for multichannel data
latent_dim:
Number of VAE latent dimensions associated with image content
nb_classes:
Number of classes for class-conditional rVAE
translation:
account for xy shifts of image content (Default: True)
seed:
seed for torch and numpy (pseudo-)random numbers generators
**conv_encoder (bool):
use convolutional layers in encoder
**numlayers_encoder (int):
number of layers in encoder (Default: 2)
**numlayers_decoder (int):
number of layers in decoder (Default: 2)
**numhidden_encoder (int):
number of hidden units OR conv filters in encoder (Default: 128)
**numhidden_decoder (int):
number of hidden units in decoder (Default: 128)
**skip (bool):
uses generative skip model with residual paths between
latents and decoder layers (Default: False)

Example:

>>> input_dim = (28, 28) # input dimensions
    >>> # Initialize model
>>> rvae = aoi.models.rVAE(input_dim)
>>> # Train
>>> rvae.fit(imstack_train, training_cycles=100,
batch_size=100, rotation_prior=np.pi/2)
>>> rvae.manifold2d(origin="upper", cmap="gnuplot2")

One can also pass labels to train a class-conditioned rVAE

    >>> # Initialize model
>>> rvae = aoi.models.rVAE(input_dim, nb_classes=10)
>>> # Train
>>> rvae.fit(imstack_train, labels_train, training_cycles=100,
>>> batch_size=100, rotation_prior=np.pi/2)
>>> # Visualize learned manifold for class 1
>>> rvae.manifold2d(label=1, origin="upper", cmap="gnuplot2")
"""

def __init__(self,
                 in_dim: Tuple[int] = None,
latent_dim: int = 2,
nb_classes: int = 0,
translation: bool = True,
seed: int = 0,
**kwargs: Union[int, bool, str]
) -> None:
"""
Initializes rVAE model
"""
coord = 3 if translation else 1 # xy translations and/or rotation
args = (in_dim, latent_dim, nb_classes, coord)
super(LrVAE, self).__init__(*args, **kwargs)
set_train_rng(seed)
self.translation = translation
self.dx_prior = None
self.phi_prior = None
self.kdict_ = dc(kwargs)
self.kdict_["num_iter"] = 0

def elbo_fn(self,
x: torch.Tensor,
x_reconstr: torch.Tensor,
y: Optional[torch.Tensor] = None,
*args: torch.Tensor,
**kwargs: Union[List, float, int]
) -> torch.Tensor:
"""
Computes ELBO
"""
return rvae_loss_lvae(self.loss, self.in_dim, x, x_reconstr, y, *args, **kwargs)

def forward_compute_elbo(self,
x: torch.Tensor,
y: Optional[torch.Tensor] = None,
mode: str = "train"
) -> torch.Tensor:
"""
rVAE's forward pass with training/test loss computation
"""
x_coord_ = self.x_coord.expand(x.size(0), *self.x_coord.size())
if mode == "eval":
with torch.no_grad():
z_mean, z_logsd = self.encoder_net(x)
else:
z_mean, z_logsd = self.encoder_net(x)
self.kdict_["num_iter"] += 1
z_sd = torch.exp(z_logsd)
z = self.reparameterize(z_mean, z_sd)
phi = z[:, 0] # angle
if self.translation:
dx = z[:, 1:3] # translation
dx = (dx * self.dx_prior).unsqueeze(1)
z = z[:, 3:] # image content
else:
dx = 0 # no translation
z = z[:, 1:] # image content

        # Note: unlike the class-conditioned rVAE, y is not concatenated to z
        # here; instead it constrains the last latent variable via the loss
        # if y is not None:
        #     targets = to_onehot(y, self.nb_classes)
        #     z = torch.cat((z, targets), -1)

x_coord_ = transform_coordinates(x_coord_, phi, dx)
if mode == "eval":
with torch.no_grad():
x_reconstr = self.decoder_net(x_coord_, z)
else:
x_reconstr = self.decoder_net(x_coord_, z)

return self.elbo_fn(x, x_reconstr, y, z_mean, z_logsd, **self.kdict_)

def fit(self,
X_train: Union[np.ndarray, torch.Tensor],
y_train: Optional[Union[np.ndarray, torch.Tensor]] = None,
X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
loss: str = "mse",
**kwargs) -> None:
"""
Trains rVAE model

Args:
X_train:
                3D or 4D stack of training images with dimensions
                (n_images, height, width) for grayscale data or
                (n_images, height, width, channels) for multi-channel data
y_train:
                Vector with labels of dimension (n_images,), where n_images
                is the number of training images
X_test:
3D or 4D stack of test images with the same dimensions
as for the X_train (Default: None)
y_test:
                Vector with labels of dimension (n_images,), where n_images
                is the number of test images
loss:
reconstruction loss function, "ce" or "mse" (Default: "mse")
**translation_prior (float):
translation prior
**rotation_prior (float):
rotational prior
**capacity (list):
List containing (max_capacity, num_iters, gamma) parameters
to control the capacity of the latent channel.
Based on https://arxiv.org/pdf/1804.03599.pdf
**filename (str):
                file path for saving model after each training cycle ("epoch")
**recording (bool):
saves a learned 2d manifold at each training step
"""
self._check_inputs(X_train, y_train, X_test, y_test)
self.dx_prior = kwargs.get("translation_prior", 0.1)
self.kdict_["phi_prior"] = kwargs.get("rotation_prior", 0.1)
for k, v in kwargs.items():
if k in ["capacity"]:
self.kdict_[k] = v
self.compile_trainer(
(X_train, y_train), (X_test, y_test), **kwargs)
self.loss = loss # this part needs to be handled better
if self.loss == "ce":
self.sigmoid_out = True # Use sigmoid layer for "prediction" stage
self.metadict["sigmoid_out"] = True
self.recording = kwargs.get("recording", False)

for e in range(self.training_cycles):
self.current_epoch = e
elbo_epoch = self.train_epoch()
self.loss_history["train_loss"].append(elbo_epoch)
if self.test_iterator is not None:
elbo_epoch_test = self.evaluate_model()
self.loss_history["test_loss"].append(elbo_epoch_test)
self.print_statistics(e)
self.update_metadict()
if self.recording and self.z_dim in [3, 5]:
self.manifold2d(savefig=True, filename=str(e))
self.save_model(self.filename)
if self.recording and self.z_dim in [3, 5]:
self.visualize_manifold_learning("./vae_learning")

def update_metadict(self):
self.metadict["num_epochs"] = self.current_epoch
self.metadict["num_iter"] = self.kdict_["num_iter"]
31 changes: 31 additions & 0 deletions atomai/models/dgm/rvae.py
@@ -146,6 +146,37 @@ def forward_compute_elbo,

return self.elbo_fn(x, x_reconstr, z_mean, z_logsd, **self.kdict_)

    def decoder_temp(self, z_mean: Union[np.ndarray, torch.Tensor],
                     y: Optional[Union[int, np.ndarray, torch.Tensor]] = None) -> np.ndarray:
        """
        Takes points in the latent space and decodes them into images
        via the decoder (conditional decoding via y is still to be added)

        Args:
            z_mean (Union[np.ndarray, torch.Tensor]):
                Latent means of shape (n, 3 + latent_dim), laid out as
                [angle, dx, dy, content latents]

        Returns:
            np.ndarray: decoded image(s)

        Example:
        >>> z_mean, z_std = rvae.encode(norm_patches)
        >>> decoded_patches = rvae.decoder_temp(z_mean)
        """
        z1, z2, z3 = z_mean[:, 0], z_mean[:, 1:3], z_mean[:, 3:]
        x_coord_ = self.x_coord.expand(z_mean.shape[0], *self.x_coord.size()).cpu()
        phi = torch.as_tensor(z1)  # rotation angle
        dx = torch.as_tensor(z2)  # xy translation
        dx = (dx * self.dx_prior).unsqueeze(1)
        x_coord_ = transform_coordinates(x_coord_, phi, dx)
        z_sample = torch.as_tensor(z3)  # image content
        if self.device == "cuda":
            x_coord_ = x_coord_.cuda()
            z_sample = z_sample.cuda()
        x_decoded = self.decoder_net(x_coord_, z_sample)
        imdec = x_decoded.detach().cpu().numpy()
        return imdec
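    # Hedged usage sketch (hypothetical values; assumes `rvae` is a trained
    # model so that dx_prior is set): decode a batch of content latents with
    # the pose latents (angle, xy shift) zeroed out, following the
    # [angle, dx, dy, content] layout assumed above.
    #
    # >>> z_content = np.random.randn(16, 2).astype(np.float32)
    # >>> z_pose = np.zeros((16, 3), dtype=np.float32)
    # >>> z_full = np.concatenate([z_pose, z_content], axis=1)
    # >>> imgs = rvae.decoder_temp(torch.from_numpy(z_full))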

def fit(self,
X_train: Union[np.ndarray, torch.Tensor],
y_train: Optional[Union[np.ndarray, torch.Tensor]] = None,