import random
from typing import Dict, List, Tuple
import numpy as np
import torch
from metatensor.torch import TensorBlock, TensorMap
from metatomic.torch import System
from scipy.spatial.transform import Rotation
from .data import TargetInfo
[docs]
def get_random_rotation():
return Rotation.random()
[docs]
def get_random_inversion():
return random.choice([1, -1])
[docs]
class RotationalAugmenter:
"""
A class to apply random rotations and inversions to a set of systems and their
targets.
:param target_info_dict: A dictionary mapping target names to their corresponding
:class:`TargetInfo` objects. This is used to determine the type of targets and
how to apply the augmentations.
"""
def __init__(self, target_info_dict: Dict[str, TargetInfo]):
# checks on targets
for target_info in target_info_dict.values():
if target_info.is_cartesian:
if len(target_info.layout.block(0).components) > 2:
raise ValueError(
"RotationalAugmenter only supports Cartesian targets "
"with `rank<=2`."
)
self.target_info_dict = target_info_dict
self.wigner = None
self.complex_to_real_spherical_harmonics_transforms = {}
is_any_target_spherical = any(
target_info.is_spherical for target_info in target_info_dict.values()
)
if is_any_target_spherical:
try:
import spherical
except ImportError:
# quaternionic (used below) is a dependency of spherical
raise ImportError(
"To use spherical targets with nanoPET, please install the "
"`spherical` package with `pip install spherical`."
)
largest_l = max(
(len(block.components[0]) - 1) // 2
for target_info in target_info_dict.values()
if target_info.is_spherical
for block in target_info.layout.blocks()
)
self.wigner = spherical.Wigner(largest_l)
for ell in range(largest_l + 1):
self.complex_to_real_spherical_harmonics_transforms[ell] = (
_complex_to_real_spherical_harmonics_transform(ell)
)
[docs]
def apply_random_augmentations(
self, systems: List[System], targets: Dict[str, TensorMap]
) -> Tuple[List[System], Dict[str, TensorMap]]:
"""
Apply a random augmentation to a number of ``System`` objects and its targets.
:param systems: A list of :class:`System` objects to be augmented.
:param targets: A dictionary mapping target names to their corresponding
:class:`TensorMap` objects. These are the targets to be augmented.
:return: A tuple containing the augmented systems and targets.
"""
rotations = [get_random_rotation() for _ in range(len(systems))]
inversions = [get_random_inversion() for _ in range(len(systems))]
transformations = [
torch.from_numpy(r.as_matrix() * i) for r, i in zip(rotations, inversions)
]
wigner_D_matrices = {}
if self.wigner is not None:
scipy_quaternions = [r.as_quat() for r in rotations]
quaternionic_quaternions = [
_scipy_quaternion_to_quaternionic(q) for q in scipy_quaternions
]
wigner_D_matrices_complex = [
self.wigner.D(q) for q in quaternionic_quaternions
]
for target_name in targets.keys():
target_info = self.target_info_dict[target_name]
if target_info.is_spherical:
for block in target_info.layout.blocks():
ell = (len(block.components[0]) - 1) // 2
if ell not in wigner_D_matrices: # skip if already computed
wigner_D_matrices_l = []
for wigner_D_matrix_complex in wigner_D_matrices_complex:
wigner_D_matrix = np.zeros(
(2 * ell + 1, 2 * ell + 1), dtype=np.complex128
)
for mp in range(-ell, ell + 1):
for m in range(-ell, ell + 1):
wigner_D_matrix[m + ell, mp + ell] = (
wigner_D_matrix_complex[
self.wigner.Dindex(ell, m, mp)
]
).conj()
U = self.complex_to_real_spherical_harmonics_transforms[
ell
]
wigner_D_matrix = U.conj() @ wigner_D_matrix @ U.T
assert np.allclose(wigner_D_matrix.imag, 0.0)
wigner_D_matrix = wigner_D_matrix.real
wigner_D_matrices_l.append(
torch.from_numpy(wigner_D_matrix)
)
wigner_D_matrices[ell] = wigner_D_matrices_l
return _apply_random_augmentations(
systems, targets, transformations, wigner_D_matrices
)
def _apply_wigner_D_matrices(
systems: List[System],
target_tmap: TensorMap,
transformations: List[torch.Tensor],
wigner_D_matrices: Dict[int, List[torch.Tensor]],
) -> TensorMap:
new_blocks: List[TensorBlock] = []
for key, block in target_tmap.items():
ell, sigma = int(key[0]), int(key[1])
values = block.values
if "atom" in block.samples.names:
split_values = torch.split(
values, [len(system.positions) for system in systems]
)
else:
split_values = torch.split(values, [1 for _ in systems])
new_values = []
ell = (len(block.components[0]) - 1) // 2
for v, transformation, wigner_D_matrix in zip(
split_values, transformations, wigner_D_matrices[ell]
):
is_inverted = torch.det(transformation) < 0
new_v = v.clone()
if is_inverted: # inversion
new_v = new_v * (-1) ** ell * sigma
# fold property dimension in, apply transformation,
# unfold property dimension
new_v = new_v.transpose(1, 2)
new_v = new_v @ wigner_D_matrix.T
new_v = new_v.transpose(1, 2)
new_values.append(new_v)
new_values = torch.concatenate(new_values)
new_block = TensorBlock(
values=new_values,
samples=block.samples,
components=block.components,
properties=block.properties,
)
new_blocks.append(new_block)
return TensorMap(
keys=target_tmap.keys,
blocks=new_blocks,
)
@torch.jit.script # script for speed
def _apply_random_augmentations( # pragma: no cover
systems: List[System],
targets: Dict[str, TensorMap],
transformations: List[torch.Tensor],
wigner_D_matrices: Dict[int, List[torch.Tensor]],
) -> Tuple[List[System], Dict[str, TensorMap]]:
# Apply the transformations to the systems
new_systems: List[System] = []
for system, transformation in zip(systems, transformations):
new_system = System(
positions=system.positions @ transformation.T,
types=system.types,
cell=system.cell @ transformation.T,
pbc=system.pbc,
)
for nl_options in system.known_neighbor_lists():
old_nl = system.get_neighbor_list(nl_options)
new_system.add_neighbor_list(
nl_options,
TensorBlock(
values=(old_nl.values.squeeze(-1) @ transformation.T).unsqueeze(-1),
samples=old_nl.samples,
components=old_nl.components,
properties=old_nl.properties,
),
)
new_systems.append(new_system)
# Apply the transformation to the targets
new_targets: Dict[str, TensorMap] = {}
for name, target_tmap in targets.items():
is_scalar = False
if len(target_tmap.blocks()) == 1:
if len(target_tmap.block().components) == 0:
is_scalar = True
is_cartesian = False
if len(target_tmap.blocks()) == 1:
if len(target_tmap.block().components) > 0:
if "xyz" in target_tmap.block().components[0].names[0]:
is_cartesian = True
is_spherical = all(
len(block.components) == 1 and block.components[0].names == ["o3_mu"]
for block in target_tmap.blocks()
)
if is_scalar:
# no change for energies
energy_block = TensorBlock(
values=target_tmap.block().values,
samples=target_tmap.block().samples,
components=target_tmap.block().components,
properties=target_tmap.block().properties,
)
if target_tmap.block().has_gradient("positions"):
# transform position gradients:
block = target_tmap.block().gradient("positions")
position_gradients = block.values.squeeze(-1)
split_sizes_forces = [system.positions.shape[0] for system in systems]
split_position_gradients = torch.split(
position_gradients, split_sizes_forces
)
position_gradients = torch.cat(
[
split_position_gradients[i] @ transformations[i].T
for i in range(len(systems))
]
)
energy_block.add_gradient(
"positions",
TensorBlock(
values=position_gradients.unsqueeze(-1),
samples=block.samples,
components=block.components,
properties=block.properties,
),
)
if target_tmap.block().has_gradient("strain"):
# transform strain gradients (rank-2 tensor):
block = target_tmap.block().gradient("strain")
strain_gradients = block.values.squeeze(-1)
split_strain_gradients = torch.split(strain_gradients, 1)
new_strain_gradients = torch.stack(
[
transformations[i]
@ split_strain_gradients[i].squeeze(0)
@ transformations[i].T
for i in range(len(systems))
],
dim=0,
)
energy_block.add_gradient(
"strain",
TensorBlock(
values=new_strain_gradients.unsqueeze(-1),
samples=block.samples,
components=block.components,
properties=block.properties,
),
)
new_targets[name] = TensorMap(
keys=target_tmap.keys,
blocks=[energy_block],
)
elif is_spherical:
new_targets[name] = _apply_wigner_D_matrices(
systems, target_tmap, transformations, wigner_D_matrices
)
elif is_cartesian:
rank = len(target_tmap.block().components)
if rank == 1:
# transform Cartesian vector:
block = target_tmap.block()
vectors = block.values
if "atom" in target_tmap.block().samples.names:
split_vectors = torch.split(
vectors, [len(system.positions) for system in systems]
)
else:
split_vectors = torch.split(vectors, [1 for _ in systems])
new_vectors = []
for v, transformation in zip(split_vectors, transformations):
# fold property dimension in, apply transformation,
# unfold property dimension
new_v = v.transpose(1, 2)
new_v = new_v @ transformation.T
new_v = new_v.transpose(1, 2)
new_vectors.append(new_v)
new_vectors = torch.cat(new_vectors)
new_targets[name] = TensorMap(
keys=target_tmap.keys,
blocks=[
TensorBlock(
values=new_vectors,
samples=block.samples,
components=block.components,
properties=block.properties,
)
],
)
elif rank == 2:
# transform Cartesian rank-2 tensor:
block = target_tmap.block()
tensor = block.values
if "atom" in target_tmap.block().samples.names:
split_tensors = torch.split(
tensor, [len(system.positions) for system in systems]
)
else:
split_tensors = torch.split(tensor, [1 for _ in systems])
new_tensors = []
for tensor, transformation in zip(split_tensors, transformations):
new_tensor = torch.einsum(
"Aa,iabp,bB->iABp", transformation, tensor, transformation.T
)
new_tensors.append(new_tensor)
new_tensors = torch.cat(new_tensors)
new_targets[name] = TensorMap(
keys=target_tmap.keys,
blocks=[
TensorBlock(
values=new_tensors,
samples=block.samples,
components=block.components,
properties=block.properties,
)
],
)
return new_systems, new_targets
def _complex_to_real_spherical_harmonics_transform(ell: int):
# Generates the transformation matrix from complex spherical harmonics
# to real spherical harmonics for a given l.
# Returns a transformation matrix of shape ((2l+1), (2l+1)).
if ell < 0 or not isinstance(ell, int):
raise ValueError("l must be a non-negative integer.")
# The size of the transformation matrix is (2l+1) x (2l+1)
size = 2 * ell + 1
U = np.zeros((size, size), dtype=complex)
for m in range(-ell, ell + 1):
m_index = m + ell # Index in the matrix
if m > 0:
# Real part of Y_{l}^{m}
U[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m
U[m_index, ell - m] = 1 / np.sqrt(2)
elif m < 0:
# Imaginary part of Y_{l}^{|m|}
U[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m
U[m_index, ell - abs(m)] = 1j / np.sqrt(2)
else: # m == 0
# Y_{l}^{0} remains unchanged
U[m_index, ell] = 1
return U
def _scipy_quaternion_to_quaternionic(q_scipy):
# This function convert a quaternion obtained from the scipy library to the format
# used by the quaternionic library.
# Note: 'xyzw' is the format used by scipy.spatial.transform.Rotation
# while 'wxyz' is the format used by quaternionic.
qx, qy, qz, qw = q_scipy
q_quaternion = np.array([qw, qx, qy, qz])
return q_quaternion