Menu

PyTorch toolbox to work with spherical surfaces.

Source code for surfify.models.vae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Cortical Spherical Variational Auto-Encoder (SVAE) models.

[1] Representation Learning of Resting State fMRI with Variational
Autoencoder: https://github.com/libilab/rsfMRI-VAE
"""

# Imports
import torch
import torch.nn as nn
from torch.distributions import Normal
from ..utils import get_logger, debug_msg
from ..nn import IcoUpConv, IcoPool, IcoSpMaConv, IcoSpMaConvTranspose
from .base import SphericalBase


# Global parameters
# Module-level logger used by the models below to emit debug traces.
logger = get_logger()


class SphericalVAE(SphericalBase):
    """ Spherical VAE architecture.

    Use either RePa - Rectangular Patch convolution method or DiNe - Direct
    Neighbor convolution method.

    The left and right hemispheres are first convolved separately, their
    feature maps are concatenated along the channel axis and then processed
    jointly down to a dense layer that parametrizes a diagonal Normal
    posterior. The decoder mirrors this organisation.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalGVAE

    References
    ----------
    Representation Learning of Resting State fMRI with Variational
    Autoencoder, NeuroImage 2021.
    """

    def __init__(self, input_channels=1, input_order=5, latent_dim=64,
                 conv_flts=(32, 32, 64, 64), conv_mode="DiNe", dine_size=1,
                 repa_size=5, repa_zoom=5, standard_ico=False,
                 cachedir=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of input channels.
        input_order: int, default 5
            the input icosahedron order.
        latent_dim: int, default 64
            the size of the stochastic latent state of the SVAE.
        conv_flts: sequence of int, default (32, 32, 64, 64)
            the size of convolutional filters. The first value must be even
            since it is shared equally between the two hemispheres.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or
            'DiNe' - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, ie. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            a multiplicative factor applied to the rectangular grid in the
            tangent space.
        standard_ico: bool, default False
            optionally use surfify tessellation.
        cachedir: str, default None
            set this folder to use smart caching speedup.

        Raises
        ------
        ValueError
            if ``conv_flts`` is empty or if its first value is odd.
        """
        logger.debug("SphericalVAE init...")
        # Work on a private copy: the caller's sequence (or the default) is
        # never shared between instances.
        conv_flts = list(conv_flts)
        if len(conv_flts) == 0:
            raise ValueError("'conv_flts' must contain at least one value.")
        if conv_flts[0] % 2 != 0:
            # Each hemisphere receives conv_flts[0] / 2 filters and both
            # outputs are concatenated: an odd value cannot be split evenly.
            raise ValueError(
                "The first value of 'conv_flts' must be even, got "
                "{0}.".format(conv_flts[0]))
        super().__init__(
            input_order=input_order, n_layers=len(conv_flts),
            conv_mode=conv_mode, dine_size=dine_size, repa_size=repa_size,
            repa_zoom=repa_zoom, standard_ico=standard_ico,
            cachedir=cachedir)
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.conv_flts = conv_flts
        # Order of the coarsest icosahedron reached by the encoder (one
        # pooling per layer except the first one).
        top_order = self.input_order - self.n_layers + 1
        self.top_flatten_dim = len(self.ico[top_order].vertices)
        self.top_final = self.conv_flts[-1] * self.top_flatten_dim
        half_flts = int(self.conv_flts[0] / 2)

        # define the encoder
        self.enc_left_conv = self.sconv(
            input_channels, half_flts,
            self.ico[self.input_order].conv_neighbor_indices)
        self.enc_right_conv = self.sconv(
            input_channels, half_flts,
            self.ico[self.input_order].conv_neighbor_indices)
        self.enc_w_conv = nn.Sequential()
        for idx in range(1, self.n_layers):
            order = self.input_order - idx
            pooling = IcoPool(
                down_neigh_indices=self.ico[order + 1].neighbor_indices,
                down_indices=self.ico[order + 1].down_indices,
                pooling_type="mean")
            self.enc_w_conv.add_module("pooling_{0}".format(idx), pooling)
            conv = self.sconv(
                self.conv_flts[idx - 1], self.conv_flts[idx],
                self.ico[order].conv_neighbor_indices)
            self.enc_w_conv.add_module("down_{0}".format(idx), conv)
        self.enc_w_dense = nn.Linear(self.top_final, self.latent_dim * 2)

        # define the decoder
        self.dec_w_dense = nn.Linear(self.latent_dim, self.top_final)
        self.dec_w_conv = nn.Sequential()
        # Start explicitly from the coarsest order instead of relying on the
        # variable leaked by the encoder loop (undefined with one layer).
        order = top_order
        for cnt, idx in enumerate(range(self.n_layers - 1, 0, -1), start=1):
            tconv = IcoUpConv(
                in_feats=self.conv_flts[idx],
                out_feats=self.conv_flts[idx - 1],
                up_neigh_indices=self.ico[order + 1].neighbor_indices,
                down_indices=self.ico[order + 1].down_indices)
            self.dec_w_conv.add_module("up_{0}".format(cnt), tconv)
            order += 1
        # NOTE(review): at this point ``order == self.input_order``; the two
        # hemisphere heads below are built from ``self.ico[input_order]``
        # exactly as before - confirm against IcoUpConv that this is the
        # intended resolution for the final layer.
        self.dec_left_conv = IcoUpConv(
            in_feats=half_flts,
            out_feats=self.input_channels,
            up_neigh_indices=self.ico[order].neighbor_indices,
            down_indices=self.ico[order].down_indices)
        self.dec_right_conv = IcoUpConv(
            in_feats=half_flts,
            out_feats=self.input_channels,
            up_neigh_indices=self.ico[order].neighbor_indices,
            down_indices=self.ico[order].down_indices)
        self.relu = nn.ReLU(inplace=True)

    def encode(self, left_x, right_x):
        """ The encoder.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, n_vertices)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, n_vertices)
            input right cortical texture.

        Returns
        -------
        q(z | x): Normal (batch_size, <latent_dim>)
            a Normal distribution.

        Notes
        -----
        NOTE(review): the vertex-wise layout of the inputs is inferred from
        the 3-D view used in ``decode`` - confirm against ``self.sconv``.
        """
        x = torch.cat(
            (self.enc_left_conv(left_x), self.enc_right_conv(right_x)),
            dim=1)
        x = self.relu(x)
        for layer in self.enc_w_conv:
            if isinstance(layer, IcoPool):
                # the pooling layer returns a tuple: keep the first item.
                x = layer(x)[0]
            else:
                x = self.relu(layer(x))
        x = x.reshape(-1, self.top_final)
        x = self.enc_w_dense(x)
        z_mu, z_logvar = torch.chunk(x, chunks=2, dim=1)
        # std = exp(logvar / 2): halving before the exponential avoids the
        # overflow of exp(logvar) for large log-variances.
        return Normal(loc=z_mu, scale=torch.exp(0.5 * z_logvar))

    def decode(self, z):
        """ The decoder.

        Parameters
        ----------
        z: Tensor (samples, <latent_dim>)
            the stochastic latent state z.

        Returns
        -------
        left_recon_x: Tensor
            reconstructed left cortical texture, as returned by the left
            up-convolution head.
        right_recon_x: Tensor
            reconstructed right cortical texture, as returned by the right
            up-convolution head.
        """
        x = self.relu(self.dec_w_dense(z))
        x = x.view(-1, self.conv_flts[-1], self.top_flatten_dim)
        for layer in self.dec_w_conv:
            x = self.relu(layer(x))
        # split the shared feature maps between the two hemispheres.
        left_recon_x, right_recon_x = torch.chunk(x, chunks=2, dim=1)
        left_recon_x = self.dec_left_conv(left_recon_x)
        right_recon_x = self.dec_right_conv(right_recon_x)
        return left_recon_x, right_recon_x

    def reparameterize(self, q):
        """ Implement the reparametrization trick.

        Parameters
        ----------
        q: Normal
            the posterior distribution returned by ``encode``.

        Returns
        -------
        z: Tensor
            a differentiable sample of ``q`` in training mode, the posterior
            mean otherwise.
        """
        if self.training:
            return q.rsample()
        return q.loc

    def forward(self, left_x, right_x):
        """ The forward method.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, n_vertices)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, n_vertices)
            input right cortical texture.

        Returns
        -------
        left_recon_x: Tensor
            reconstructed left cortical texture.
        right_recon_x: Tensor
            reconstructed right cortical texture.
        extra: dict
            the posterior distribution 'q' and the latent sample 'z'.
        """
        logger.debug("SphericalVAE forward pass")
        logger.debug(debug_msg("left cortical", left_x))
        logger.debug(debug_msg("right cortical", right_x))
        q = self.encode(left_x, right_x)
        logger.debug(debug_msg("posterior loc", q.loc))
        logger.debug(debug_msg("posterior scale", q.scale))
        z = self.reparameterize(q)
        logger.debug(debug_msg("z", z))
        left_recon_x, right_recon_x = self.decode(z)
        logger.debug(debug_msg("left recon cortical", left_recon_x))
        logger.debug(debug_msg("right recon cortical", right_recon_x))
        return left_recon_x, right_recon_x, {"q": q, "z": z}


class SphericalGVAE(nn.Module):
    """ Spherical Gridded VAE architecture.

    Use SpMa - Spherical Mapping convolution method.

    The left and right hemispheres, projected on 2-D grids, are first
    convolved separately, their feature maps are concatenated along the
    channel axis and then processed jointly down to a dense layer that
    parametrizes a diagonal Normal posterior. The decoder mirrors this
    organisation.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalVAE

    References
    ----------
    Representation Learning of Resting State fMRI with Variational
    Autoencoder, NeuroImage 2021.
    """

    def __init__(self, input_channels=1, input_dim=192, latent_dim=64,
                 conv_flts=(64, 128, 128, 256, 256)):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of input channels.
        input_dim: int, default 192
            the size of the converted 3-D surface to the 2-D grid.
        latent_dim: int, default 64
            the size of the stochastic latent state of the SVAE.
        conv_flts: sequence of int, default (64, 128, 128, 256, 256)
            the size of convolutional filters. The first value must be even
            since it is shared equally between the two hemispheres.

        Raises
        ------
        ValueError
            if ``conv_flts`` is empty or if its first value is odd.
        """
        logger.debug("SphericalGVAE init...")
        # Work on a private copy: the caller's sequence (or the default) is
        # never shared between instances.
        conv_flts = list(conv_flts)
        if len(conv_flts) == 0:
            raise ValueError("'conv_flts' must contain at least one value.")
        if conv_flts[0] % 2 != 0:
            # Each hemisphere receives conv_flts[0] / 2 filters and both
            # outputs are concatenated: an odd value cannot be split evenly.
            raise ValueError(
                "The first value of 'conv_flts' must be even, got "
                "{0}.".format(conv_flts[0]))
        super().__init__()
        self.input_channels = input_channels
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.conv_flts = conv_flts
        self.n_layers = len(self.conv_flts)
        # every layer uses a stride of 2, hence the 2 ** n_layers reduction.
        self.top_flatten_dim = int(self.input_dim / (2 ** self.n_layers))
        self.top_final = self.conv_flts[-1] * self.top_flatten_dim ** 2
        half_flts = int(self.conv_flts[0] / 2)

        # define the encoder
        self.enc_left_conv = IcoSpMaConv(
            in_feats=self.input_channels,
            out_feats=half_flts,
            kernel_size=8, stride=2, pad=3)
        self.enc_right_conv = IcoSpMaConv(
            in_feats=self.input_channels,
            out_feats=half_flts,
            kernel_size=8, stride=2, pad=3)
        self.enc_w_conv = nn.ModuleList([
            IcoSpMaConv(self.conv_flts[i - 1], self.conv_flts[i],
                        kernel_size=4, stride=2, pad=1)
            for i in range(1, self.n_layers)])
        self.enc_w_dense = nn.Linear(self.top_final, self.latent_dim * 2)

        # define the decoder
        self.dec_w_dense = nn.Linear(self.latent_dim, self.top_final)
        self.dec_w_conv = nn.ModuleList([
            IcoSpMaConvTranspose(
                in_feats=self.conv_flts[i], out_feats=self.conv_flts[i - 1],
                kernel_size=4, stride=2, pad=1, zero_pad=3)
            for i in range(self.n_layers - 1, 0, -1)])
        self.dec_left_conv = IcoSpMaConvTranspose(
            in_feats=half_flts,
            out_feats=self.input_channels,
            kernel_size=8, stride=2, pad=3, zero_pad=9)
        self.dec_right_conv = IcoSpMaConvTranspose(
            in_feats=half_flts,
            out_feats=self.input_channels,
            kernel_size=8, stride=2, pad=3, zero_pad=9)
        self.relu = nn.ReLU(inplace=True)

    def encode(self, left_x, right_x):
        """ The encoder.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        q(z | x): Normal (batch_size, <latent_dim>)
            a Normal distribution.
        """
        x = torch.cat(
            (self.enc_left_conv(left_x), self.enc_right_conv(right_x)),
            dim=1)
        x = self.relu(x)
        for layer in self.enc_w_conv:
            x = self.relu(layer(x))
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.top_final)
        x = self.enc_w_dense(x)
        z_mu, z_logvar = torch.chunk(x, chunks=2, dim=1)
        # std = exp(logvar / 2): halving before the exponential avoids the
        # overflow of exp(logvar) for large log-variances.
        return Normal(loc=z_mu, scale=torch.exp(0.5 * z_logvar))

    def decode(self, z):
        """ The decoder.

        Parameters
        ----------
        z: Tensor (samples, <latent_dim>)
            the stochastic latent state z.

        Returns
        -------
        left_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed left cortical texture.
        right_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed right cortical texture.
        """
        x = self.relu(self.dec_w_dense(z))
        x = x.view(-1, self.conv_flts[-1], self.top_flatten_dim,
                   self.top_flatten_dim)
        for layer in self.dec_w_conv:
            x = self.relu(layer(x))
        # split the shared feature maps between the two hemispheres.
        left_recon_x, right_recon_x = torch.chunk(x, chunks=2, dim=1)
        left_recon_x = self.dec_left_conv(left_recon_x)
        right_recon_x = self.dec_right_conv(right_recon_x)
        return left_recon_x, right_recon_x

    def reparameterize(self, q):
        """ Implement the reparametrization trick.

        Parameters
        ----------
        q: Normal
            the posterior distribution returned by ``encode``.

        Returns
        -------
        z: Tensor
            a differentiable sample of ``q`` in training mode, the posterior
            mean otherwise.
        """
        if self.training:
            return q.rsample()
        return q.loc

    def forward(self, left_x, right_x):
        """ The forward method.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        left_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed left cortical texture.
        right_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed right cortical texture.
        extra: dict
            the posterior distribution 'q' and the latent sample 'z'.
        """
        logger.debug("SphericalGVAE forward pass")
        logger.debug(debug_msg("left cortical", left_x))
        logger.debug(debug_msg("right cortical", right_x))
        q = self.encode(left_x, right_x)
        logger.debug(debug_msg("posterior loc", q.loc))
        logger.debug(debug_msg("posterior scale", q.scale))
        z = self.reparameterize(q)
        logger.debug(debug_msg("z", z))
        left_recon_x, right_recon_x = self.decode(z)
        logger.debug(debug_msg("left recon cortical", left_recon_x))
        logger.debug(debug_msg("right recon cortical", right_recon_x))
        return left_recon_x, right_recon_x, {"q": q, "z": z}

Follow us

© 2021, surfify developers