Menu

PyTorch toolbox to work with spherical surfaces.

Source code for surfify.models.unet

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
The spherical UNet architecture.
"""

# Imports
import numpy as np
import torch
import torch.nn as nn
from joblib import Memory
from ..utils import number_of_ico_vertices, get_logger, debug_msg
from ..nn import (
    IcoUpConv, IcoMaxIndexUpSample, IcoFixIndexUpSample, IcoUpSample, IcoPool,
    IcoSpMaConv, IcoSpMaConvTranspose)
from .base import SphericalBase


# Global parameters
# Module-level logger, obtained from the package ``get_logger`` helper
# (imported from ``..utils``); every class below emits its debug messages
# through it.
logger = get_logger()


class SphericalUNet(SphericalBase):
    """ The Spherical U-Net architecture.

    The architecture is built upon specific spherical surface convolution,
    pooling, and transposed convolution modules. It has an encoder path and
    a decoder path, with a user-defined number of resolution steps.
    Different from the standard U-Net, all 3x3 convolutions are replaced
    with the RePa or DiNe convolution, 2x2 up-convolutions with surface
    transposed convolution or surface upsampling, and 2x2 max pooling with
    surface max/mean pooling. In addition to the standard U-Net, a batch
    normalization layer is added before each convolution layer's activation
    function (a LeakyReLU with slope 0.2 in this implementation). At the
    final layer, the 1x1 convolution is replaced by a vertex-wise filter.
    The number of feature channels is doubled after each surface pooling
    layer and halved at each transposed convolution or upsampling layer.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalGUNet

    Examples
    --------
    >>> import torch
    >>> from surfify.models import SphericalUNet
    >>> from surfify.utils import icosahedron
    >>> vertices, triangles = icosahedron(order=2)
    >>> model = SphericalUNet(
    ...     in_order=2, in_channels=2, out_channels=4, depth=2,
    ...     start_filts=8, conv_mode="DiNe", dine_size=1, up_mode="interp",
    ...     standard_ico=False)
    >>> x = torch.zeros((10, 2, len(vertices)))
    >>> out = model(x)
    >>> out.shape

    References
    ----------
    Zhao F, et al., Spherical U-Net on Cortical Surfaces: Methods and
    Applications, IPMI, 2019.
    """
    def __init__(self, in_order, in_channels, out_channels, depth=5,
                 start_filts=32, conv_mode="DiNe", dine_size=1, repa_size=5,
                 repa_zoom=5, dynamic_repa_zoom=False, up_mode="interp",
                 standard_ico=False, cachedir=None):
        """ Init SphericalUNet.

        Parameters
        ----------
        in_order: int
            the input icosahedron order.
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        depth: int, default 5
            number of layers in the UNet (must be at least 1).
        start_filts: int, default 32
            number of convolutional filters for the first conv.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or
            'DiNe' - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, ie. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative
            factor of `log(order + 1) + 1`.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose convolution
            (1 ring), 'interp' for nearest neighbor linear interpolation,
            'maxpad' for max pooling shifted zero padding, and 'zeropad' for
            classical zero padding.
        standard_ico: bool, default False
            optionally use surfify tessellation (forwarded to
            ``SphericalBase``).
        cachedir: str, default None
            set this folder to use smart caching speedup.

        Raises
        ------
        ValueError
            if ``depth`` is lower than 1: the network needs at least one
            encoder block (``self.filts[1]`` sizes the final layer).
        """
        if depth < 1:
            raise ValueError(
                "The UNet depth must be at least 1, got {0}.".format(depth))
        logger.debug("SphericalUNet init...")
        super(SphericalUNet, self).__init__(
            input_order=in_order, n_layers=depth,
            conv_mode=conv_mode, dine_size=dine_size, repa_size=repa_size,
            repa_zoom=repa_zoom, dynamic_repa_zoom=dynamic_repa_zoom,
            standard_ico=standard_ico, cachedir=cachedir)
        self.memory = Memory(cachedir, verbose=0)
        self.in_order = in_order
        self.depth = depth
        self.in_vertices = number_of_ico_vertices(order=in_order)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up_mode = up_mode
        self.filts = [in_channels] + [
            start_filts * 2 ** idx for idx in range(depth)]
        logger.debug("- filters: {0}".format(self.filts))

        # Encoder: block idx works at icosahedron order (in_order - idx);
        # every block except the first one pools from order + 1 to order.
        for idx in range(depth):
            order = self.in_order - idx
            first = (idx == 0)
            logger.debug(
                "- DownBlock {0}: {1} -> {2} [{3} - {4} - {5}]".format(
                    idx, self.filts[idx], self.filts[idx + 1],
                    self.ico[order].neighbor_indices.shape,
                    (None if first else
                     self.ico[order + 1].neighbor_indices.shape),
                    (None if first else
                     self.ico[order + 1].down_indices.shape)))
            block = DownBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx],
                out_ch=self.filts[idx + 1],
                conv_neigh_indices=self.ico[order].conv_neighbor_indices,
                down_neigh_indices=(
                    None if first else self.ico[order + 1].neighbor_indices),
                down_indices=(
                    None if first else self.ico[order + 1].down_indices),
                pool_mode=("max" if self.up_mode == "maxpad" else "mean"),
                first=first)
            setattr(self, "down{0}".format(idx + 1), block)

        # Decoder: up block cnt starts from encoder level idx (icosahedron
        # order in_order - idx) and upsamples it to order + 1, walking back
        # from the coarsest level towards the input order.
        for cnt, idx in enumerate(range(depth - 1, 0, -1), start=1):
            order = self.in_order - idx
            logger.debug("- UpBlock {0}: {1} -> {2} [{3} - {4}]".format(
                cnt, self.filts[idx + 1], self.filts[idx],
                self.ico[order + 1].neighbor_indices.shape,
                self.ico[order].up_indices.shape))
            block = UpBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx + 1],
                out_ch=self.filts[idx],
                conv_neigh_indices=self.ico[order + 1].conv_neighbor_indices,
                neigh_indices=self.ico[order + 1].neighbor_indices,
                up_neigh_indices=self.ico[order].up_indices,
                down_indices=self.ico[order + 1].down_indices,
                up_mode=self.up_mode)
            setattr(self, "up{0}".format(cnt), block)

        logger.debug("- FC: {0} -> {1}".format(self.filts[1], out_channels))
        self.fc = nn.Sequential(
            nn.Linear(self.filts[1], out_channels))

    def forward(self, x):
        """ Forward method.

        Parameters
        ----------
        x: torch.Tensor
            input features. Dimension 2 must hold the vertices of the
            ``in_order`` icosahedron (checked below); channels are presumably
            on dimension 1, ie. (n_samples, in_channels, n_vertices).

        Returns
        -------
        x: torch.Tensor
            output features of shape (n_samples, out_channels, n_vertices).

        Raises
        ------
        RuntimeError
            if the input does not have the expected number of vertices.
        """
        logger.debug("SphericalUNet...")
        logger.debug(debug_msg("input", x))
        if x.size(2) != self.in_vertices:
            raise RuntimeError("Input data must be projected on an {0} order "
                               "icosahedron.".format(self.in_order))
        encoder_outs = []
        pooling_outs = []
        for idx in range(1, self.depth + 1):
            down_block = getattr(self, "down{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, down_block))
            x, max_pool_indices = down_block(x)
            encoder_outs.append(x)
            pooling_outs.append(max_pool_indices)
        # Walk the decoder from the coarsest to the finest level: the skip
        # connection of up block idx is the encoder output one level above,
        # and its unpooling indices come from the pooling that produced the
        # level it starts from.
        encoder_outs = encoder_outs[::-1]
        pooling_outs = pooling_outs[::-1]
        for idx in range(1, self.depth):
            up_block = getattr(self, "up{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, up_block))
            x_up = encoder_outs[idx]
            max_pool_indices = pooling_outs[idx - 1]
            x = up_block(x, x_up, max_pool_indices)
        logger.debug("FC...")
        logger.debug(debug_msg("input", x))
        # Vertex-wise filter: the same linear map is applied to the feature
        # vector of every vertex.
        n_samples = len(x)
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * self.in_vertices, self.filts[1])
        x = self.fc(x)
        x = x.view(n_samples, self.in_vertices, self.out_channels)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("output", x))
        return x
class DownBlock(nn.Module):
    """ Downsampling block in spherical UNet:
    pooling (mean or max, skipped for the first block) =>
    (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 down_neigh_indices, down_indices, pool_mode="mean",
                 first=False):
        """ Init DownBlock.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere,
            called as ``conv_layer(in_ch, out_ch, conv_neigh_indices)``.
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        conv_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i.
        down_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        pool_mode: str, default 'mean'
            the pooling mode: 'mean' or 'max'.
        first: bool, default False
            if set, no pooling layer is created and the block only applies
            the two convolutions.
        """
        super(DownBlock, self).__init__()
        self.first = first
        if not self.first:
            self.pooling = IcoPool(
                down_neigh_indices, down_indices, pool_mode)
        # Two (conv => BN => LeakyReLU) stages: the first maps in_ch to
        # out_ch, the second keeps out_ch. Layers are created in the same
        # order as a flat nn.Sequential so parameter names are unchanged.
        stages = []
        for n_in in (in_ch, out_ch):
            stages.extend([
                conv_layer(n_in, out_ch, conv_neigh_indices),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                               track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True)])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """ Forward method.

        Returns
        -------
        out: torch.Tensor
            the convolved features.
        indices: object or None
            the second output of the pooling layer (max pooling indices,
            possibly None), or None when this is the first block.
        """
        logger.debug("- DownBlock")
        logger.debug(debug_msg("input", x))
        indices = None
        if not self.first:
            x, indices = self.pooling(x)
            logger.debug(debug_msg("pooling", x))
            if indices is not None:
                logger.debug(debug_msg("max pooling indices", indices))
        out = self.double_conv(x)
        logger.debug(debug_msg("output", out))
        return out, indices
class UpBlock(nn.Module):
    """ Define the upsampling block in spherical UNet:
    upsampling => concatenation with the skip connection =>
    (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 neigh_indices, up_neigh_indices, down_indices,
                 up_mode="interp"):
        """ Init UpBlock.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere,
            called as ``conv_layer(in_ch, out_ch, conv_neigh_indices)``.
        in_ch: int
            input features/channels. The first convolution is applied to
            the concatenation of the upsampled features and the skip
            connection, so that concatenation must have ``in_ch`` channels.
        out_ch: int
            output features/channels.
        conv_neigh_indices: tensor, int
            conv layer's filters' neighborhood indices at sampling i.
        neigh_indices: tensor, int
            neighborhood indices at sampling i.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose convolution,
            'interp' for nearest neighbor linear interpolation,
            'maxpad' for max pooling shifted zero padding, and 'zeropad' for
            classical zero padding.

        Raises
        ------
        ValueError
            if ``up_mode`` is not one of the supported methods.
        """
        super(UpBlock, self).__init__()
        self.up_mode = up_mode
        if up_mode == "interp":
            self.up = IcoUpSample(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "zeropad":
            self.up = IcoFixIndexUpSample(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "maxpad":
            self.up = IcoMaxIndexUpSample(
                in_ch, out_ch, neigh_indices, down_indices)
        elif up_mode == "transpose":
            self.up = IcoUpConv(
                in_ch, out_ch, neigh_indices, down_indices)
        else:
            raise ValueError("Invalid upsampling method.")
        self.double_conv = nn.Sequential(
            conv_layer(in_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            conv_layer(out_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x1, x2, max_pool_indices):
        """ Forward method.

        Parameters
        ----------
        x1: torch.Tensor
            features to upsample.
        x2: torch.Tensor
            skip-connection features, concatenated with the upsampled
            ``x1`` along dimension 1.
        max_pool_indices: object
            indices forwarded to the upsampling layer; only used when
            ``up_mode`` is 'maxpad'.
        """
        logger.debug("- UpBlock")
        logger.debug(debug_msg("input", x1))
        logger.debug(debug_msg("skip", x2))
        if self.up_mode == "maxpad":
            x1 = self.up(x1, max_pool_indices)
        else:
            x1 = self.up(x1)
        logger.debug(debug_msg("upsampling", x1))
        x = torch.cat((x1, x2), 1)
        logger.debug(debug_msg("cat", x))
        x = self.double_conv(x)
        logger.debug(debug_msg("output", x))
        return x
class SphericalGUNet(nn.Module):
    """ The Spherical Grided U-Net architecture.

    The architecture is built upon specific spherical surface convolution,
    pooling, and transposed convolution modules. It has an encoder path and
    a decoder path, with a user-defined number of resolution steps.
    Different from the standard U-Net, all 3x3 convolutions are replaced
    with the SpMa convolution. In addition to the standard U-Net, a batch
    normalization layer is added before each convolution layer's activation
    function (a LeakyReLU with slope 0.2 in this implementation). The
    number of feature channels is doubled after each pooling layer and
    halved at each transposed convolution layer.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalUNet

    References
    ----------
    Zhao F, et al., Spherical U-Net on Cortical Surfaces: Methods and
    Applications, IPMI, 2019.
    """
    def __init__(self, in_channels, out_channels, input_dim=192, depth=5,
                 start_filts=32):
        """ Init SphericalGUNet.

        Parameters
        ----------
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        input_dim: int, default 192
            the size of the converted 3-D surface to the 2-D grid (only
            stored as an attribute by this class).
        depth: int, default 5
            number of layers in the UNet (must be at least 1).
        start_filts: int, default 32
            number of convolutional filters for the first conv.

        Raises
        ------
        ValueError
            if ``depth`` is lower than 1: the network needs at least one
            encoder block (``self.filts[1]`` sizes the final convolution).
        """
        if depth < 1:
            raise ValueError(
                "The UNet depth must be at least 1, got {0}.".format(depth))
        logger.debug("SphericalGUNet init...")
        super(SphericalGUNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_dim = input_dim
        self.depth = depth
        self.start_filts = start_filts
        self.filts = [in_channels] + [
            start_filts * 2 ** idx for idx in range(depth)]
        logger.debug("- filters: {0}".format(self.filts))
        for idx in range(depth):
            logger.debug(
                "- DownGBlock {0}: {1} -> {2}".format(
                    idx, self.filts[idx], self.filts[idx + 1]))
            block = DownGBlock(
                in_ch=self.filts[idx],
                out_ch=self.filts[idx + 1],
                first=(idx == 0))
            setattr(self, "down{0}".format(idx + 1), block)
        # Decoder blocks go from the deepest level back to the first one.
        for cnt, idx in enumerate(range(depth - 1, 0, -1), start=1):
            logger.debug("- UpGBlock {0}: {1} -> {2}".format(
                cnt, self.filts[idx + 1], self.filts[idx]))
            block = UpGBlock(
                in_ch=self.filts[idx + 1],
                out_ch=self.filts[idx])
            setattr(self, "up{0}".format(cnt), block)
        logger.debug("- Conv 1x1 final: {0} -> {1}".format(
            self.filts[1], out_channels))
        self.conv_final = nn.Conv2d(
            self.filts[1], out_channels, kernel_size=1, groups=1, stride=1)

    def forward(self, x):
        """ Forward method.

        Parameters
        ----------
        x: torch.Tensor
            gridded input features; the 2-D batch normalization and max
            pooling layers require a 4-D input with channels on dimension 1,
            ie. (n_samples, in_channels, H, W).

        Returns
        -------
        x: torch.Tensor
            output of the final 1x1 convolution, with ``out_channels``
            channels on dimension 1.
        """
        logger.debug("SphericalGUNet...")
        logger.debug(debug_msg("input", x))
        encoder_outs = []
        for idx in range(1, self.depth + 1):
            down_block = getattr(self, "down{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, down_block))
            x = down_block(x)
            encoder_outs.append(x)
        # Up block idx takes its skip connection from the encoder level just
        # above the one it starts from.
        encoder_outs = encoder_outs[::-1]
        for idx in range(1, self.depth):
            up_block = getattr(self, "up{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, up_block))
            x_up = encoder_outs[idx]
            x = up_block(x, x_up)
        x = self.conv_final(x)
        logger.debug(debug_msg("output", x))
        return x
class DownGBlock(nn.Module):
    """ Downsampling block in grided spherical UNet:
    2x2 max pooling (skipped for the first block) =>
    (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, in_ch, out_ch, first=False):
        """ Init DownGBlock.

        Parameters
        ----------
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        first: bool, default False
            if set, no pooling layer is created and the block only applies
            the two convolutions.
        """
        super(DownGBlock, self).__init__()
        self.first = first
        if not self.first:
            self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        # Two (SpMa conv => BN => LeakyReLU) stages built in the same order
        # as a flat nn.Sequential, so parameter names are unchanged.
        stages = []
        for feats_in in (in_ch, out_ch):
            stages += [
                IcoSpMaConv(in_feats=feats_in, out_feats=out_ch,
                            kernel_size=3, pad=1),
                nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                               track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True)]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """ Forward method: optional pooling then the two convolutions.
        """
        logger.debug("- DownGBlock")
        logger.debug(debug_msg("input", x))
        if not self.first:
            x = self.pooling(x)
            logger.debug(debug_msg("pooling", x))
        out = self.double_conv(x)
        logger.debug(debug_msg("output", out))
        return out
class UpGBlock(nn.Module):
    """ Define the upsampling block in grided spherical UNet:
    transposed conv => concatenation with the skip connection =>
    (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, in_ch, out_ch):
        """ Init UpGBlock.

        Parameters
        ----------
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        """
        super(UpGBlock, self).__init__()
        self.up = IcoSpMaConvTranspose(
            in_feats=in_ch, out_feats=out_ch, kernel_size=4, stride=2,
            pad=1, zero_pad=3)
        # Two (SpMa conv => BN => LeakyReLU) stages built in the same order
        # as a flat nn.Sequential, so parameter names are unchanged.
        stages = []
        for feats_in in (in_ch, out_ch):
            stages += [
                IcoSpMaConv(in_feats=feats_in, out_feats=out_ch,
                            kernel_size=3, pad=1),
                nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                               track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True)]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x1, x2):
        """ Forward method: upsample ``x1``, concatenate it with the skip
        features ``x2`` along dimension 1, then apply the two convolutions.
        """
        logger.debug("- UpGBlock")
        logger.debug(debug_msg("input", x1))
        logger.debug(debug_msg("skip", x2))
        upsampled = self.up(x1)
        logger.debug(debug_msg("upsampling", upsampled))
        merged = torch.cat((upsampled, x2), 1)
        logger.debug(debug_msg("cat", merged))
        out = self.double_conv(merged)
        logger.debug(debug_msg("output", out))
        return out

Follow us

© 2023, surfify developers