Menu

PyTorch toolbox to work with spherical surfaces.

Source code for surfify.augmentation.augmentation

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Surface augmentation tools.
"""

# Imports
import numbers
import datetime
import numpy as np
from surfify.utils import (
    neighbors, rotate_data, find_neighbors, find_rotation_interpol_coefs)
from surfify.utils.io import compute_and_store


class SphericalRandomRotation(object):
    """ Rotation of the icosahedron's vertices.

    The rotation angles are drawn uniformly from per-axis intervals, either
    once at construction (fixed angle) or at every call.

    See Also
    --------
    rotate_data

    Examples
    --------
    >>> from surfify.utils import icosahedron
    >>> from surfify.datasets import make_classification
    >>> from surfify.augmentation import SphericalRandomRotation
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> X, y = make_classification(ico3_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> processor = SphericalRandomRotation(
            ico3_verts, ico3_tris, angles=(45, 0, 0))
    >>> y_rot = processor(y)
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, texture=y_rot,
                     is_label=False)
    >>> plt.show()
    """
    def __init__(self, vertices, triangles, angles=(5, 0, 0),
                 fixed_angle=True, interpolation="barycentric",
                 cachedir=None):
        """ Init class.

        Parameters
        ----------
        vertices: array (N, 3)
            icosahedron's vertices.
        triangles: array (M, 3)
            icosahedron's triangles.
        angles: 3-uplet, default (5, 0, 0)
            the rotation angles intervals in degrees for each axis (Euler
            representation). Each item is either a number (converted to a
            symmetric interval) or a 2-uplet (min, max), see `interval`.
        fixed_angle: bool, default True
            if True, the rotation angles are drawn once at construction and
            reused at every call. If False, new angles are drawn from the
            intervals at each call: this option slows down the training as
            the rotation needs to be initialized at each call.
        interpolation: str, default 'barycentric'
            type of interpolation to use by the rotate_data function, see
            `rotate_data`.
        cachedir: str, default None
            set this folder to use smart caching speedup.
        """
        self.vertices = vertices
        self.triangles = triangles
        # Normalize every per-axis specification to a (min, max) interval.
        self.angles = [interval(val) for val in angles]
        self.fixed_angle = fixed_angle
        if fixed_angle:
            # Collapse each interval to a single angle, sampled only once.
            self.angles = [
                np.random.uniform(val[0], val[1]) for val in self.angles]
        # NOTE(review): this attribute is stored but not forwarded to the
        # rotation call in __call__ - confirm against the `rotate_data`
        # signature whether it should be passed explicitly.
        self.interpolation = interpolation
        self.rotate_data_cached = compute_and_store(
            find_rotation_interpol_coefs, cachedir)(rotate_data)

    def __call__(self, data):
        """ Rotates the provided vertices and projects the input data
        accordingly.

        Parameters
        ----------
        data: array (N, )
            input data/texture.

        Returns
        -------
        rot_data: arr (N, )
            rotated input data.
        """
        angles = self.angles
        if not self.fixed_angle:
            # Reseed from the wall clock only when new angles are actually
            # drawn: with fixed angles nothing random happens here, so the
            # caller's global NumPy RNG state must be left untouched.
            now = datetime.datetime.now()
            np.random.seed(now.second + now.microsecond)
            angles = [
                np.random.uniform(val[0], val[1]) for val in self.angles]
        # The rotation is fed a leading singleton axis, removed afterwards.
        return self.rotate_data_cached(
            data[np.newaxis, :], self.vertices, self.triangles,
            angles).squeeze()
class SphericalRandomCut(object):
    """ Random cut of patches on the icosahedron: use Direct Neighbors
    (DiNe) to build patches.

    See Also
    --------
    neighbors

    Examples
    --------
    >>> from surfify.utils import icosahedron
    >>> from surfify.datasets import make_classification
    >>> from surfify.augmentation import SphericalRandomCut
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> X, y = make_classification(ico3_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> processor = SphericalRandomCut(
            ico3_verts, ico3_tris, replacement_value=5)
    >>> y_cut = processor(y)
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, texture=y_cut,
                     is_label=True)
    >>> plt.show()
    """
    def __init__(self, vertices, triangles, neighs=None, patch_size=3,
                 n_patches=1, replacement_value=0):
        """ Init class.

        Parameters
        ----------
        vertices: array (N, 3)
            icosahedron's vertices.
        triangles: array (M, 3)
            icosahedron's triangles.
        neighs: dict, default None
            optionally specify the DiNe neighbors of each vertex as built
            with `surfify.utils.neighbors`, ie. a dictionary with vertices
            row index as keys and a dictionary of neighbors vertices row
            indexes organized by rings as values. Computed from the mesh
            when not provided.
        patch_size: int, default 3
            the number of neighboring rings from one node to be considered
            during the ablation.
        n_patches: int, default 1
            the number of patches to be considered.
        replacement_value: float, default 0
            the replacement patch value.
        """
        self.vertices = vertices
        self.triangles = triangles
        self.patch_size = patch_size
        self.n_patches = n_patches
        self.replacement_value = replacement_value
        # Only build the neighborhood structure when the caller gave none.
        self.neighs = (
            neighs if neighs is not None
            else neighbors(vertices, triangles, direct_neighbor=True))

    def __call__(self, data):
        """ Applies the cut out (ablation) augmentation to the data.

        Parameters
        ----------
        data: array (N, )
            input data/texture.

        Returns
        -------
        cut_data: arr (N, )
            ablated input data: a copy of the input where each selected
            patch is filled with the replacement value.
        """
        ablated = data.copy()
        n_vertices = len(self.vertices)
        for _ in range(self.n_patches):
            # Pick a random patch center, then blank out its neighborhood.
            center = np.random.randint(0, n_vertices)
            ablated[find_neighbors(
                center, self.patch_size, self.neighs)] = self.replacement_value
        return ablated
def interval(bound, lower=None):
    """ Create an interval.

    Parameters
    ----------
    bound: 2-uplet or number
        the object used to build the interval: either the (min, max)
        boundaries, or the upper bound of the interval.
    lower: number, default None
        the lower bound of the interval, only used when 'bound' is a number.
        If not specified, a symmetric interval is generated.

    Returns
    -------
    interval: 2-uplet
        an interval.

    Raises
    ------
    ValueError
        if a number 'bound' is negative, if 'bound' is not a number and does
        not contain exactly 2 values, or if the lower boundary is greater
        than the upper one.
    """
    if isinstance(bound, numbers.Number):
        if bound < 0:
            raise ValueError("Specified interval value must be positive.")
        if lower is None:
            lower = -bound
        elif lower > bound:
            # Same ordering check as for an explicit (min, max) 2-uplet, so
            # an inverted interval is never returned.
            raise ValueError("Wrong interval boundaries.")
        return (lower, bound)
    if len(bound) != 2:
        raise ValueError("Interval must be specified with 2 values.")
    min_val, max_val = bound
    if min_val > max_val:
        raise ValueError("Wrong interval boundaries.")
    return tuple(bound)

Follow us

© 2021, surfify developers