Source code for metatrain.utils.long_range

from typing import List

import torch
from metatomic.torch import System


class LongRangeFeaturizer(torch.nn.Module):
    """A class to compute long-range features starting from short-range features.

    The short-range features are mapped to per-atom "charges" by a linear
    layer. A Coulomb-like potential generated by those charges is computed with
    torch-pme, and the long-range features are ``potential * charges``.

    :param hypers: Dictionary containing the hyperparameters for the long-range
        featurizer. The keys read here are ``"smearing"``,
        ``"kspace_resolution"``, ``"interpolation_nodes"`` and ``"use_ewald"``.
    :param feature_dim: The dimension of the short-range features (which also
        corresponds to the number of long-range features that will be returned).
    :param neighbor_list_options: A :py:class:`NeighborListOptions` object
        containing the neighbor list information for the short-range model.
    :raises ImportError: If ``torch-pme`` is not installed.
    """

    def __init__(self, hypers, feature_dim, neighbor_list_options):
        super().__init__()

        # torch-pme is an optional dependency; import it lazily so the rest of
        # the package works without it.
        try:
            from torchpme import CoulombPotential
            from torchpme.calculators import (
                Calculator,
                EwaldCalculator,
                P3MCalculator,
            )
        except ImportError as err:
            raise ImportError(
                "`torch-pme` is required for long-range models. "
                "Please install it with `pip install torch-pme`."
            ) from err

        # Both periodic calculators use a smeared Coulomb potential whose
        # exclusion radius equals the short-range cutoff.
        # NOTE(review): presumably this avoids double counting what the
        # short-range model already sees; confirm against the torch-pme docs.
        # ``kspace_resolution`` doubles as the Ewald wavelength and the P3M
        # mesh spacing.
        self.ewald_calculator = EwaldCalculator(
            potential=CoulombPotential(
                smearing=float(hypers["smearing"]),
                exclusion_radius=neighbor_list_options.cutoff,
            ),
            full_neighbor_list=neighbor_list_options.full_list,
            lr_wavelength=float(hypers["kspace_resolution"]),
        )
        self.p3m_calculator = P3MCalculator(
            potential=CoulombPotential(
                smearing=float(hypers["smearing"]),
                exclusion_radius=neighbor_list_options.cutoff,
            ),
            interpolation_nodes=hypers["interpolation_nodes"],
            full_neighbor_list=neighbor_list_options.full_list,
            mesh_spacing=float(hypers["kspace_resolution"]),
        )
        self.use_ewald = hypers["use_ewald"]

        # Direct (unsmeared) sum used for non-periodic systems. The pairs are
        # produced by ``torch.combinations``, which yields each unordered pair
        # exactly once (i < j), hence a half neighbor list.
        self.direct_calculator = Calculator(
            potential=CoulombPotential(
                smearing=None,
                exclusion_radius=neighbor_list_options.cutoff,
            ),
            full_neighbor_list=False,  # see docs of torch.combinations
        )
        self.neighbor_list_options = neighbor_list_options
        self.charges_map = torch.nn.Linear(feature_dim, feature_dim)

    def forward(
        self,
        systems: List[System],
        features: torch.Tensor,
        neighbor_distances: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the long-range features for a list of systems.

        :param systems: A list of :py:class:`System` objects for which to
            compute the long-range features. Each system must contain a neighbor
            list consistent with the neighbor list options used to create the
            class.
        :param features: A tensor of short-range features for the systems. The
            code assumes the atoms of all systems are stacked along the first
            dimension, in the order of ``systems``.
        :param neighbor_distances: A tensor of neighbor distances for the
            systems, which must be consistent with the neighbor list options
            used to create the class. The code assumes the pairs of all systems
            are concatenated, in the order of ``systems``.
        :return: The long-range features (``potential * charges``) for all
            atoms of all systems, concatenated along the first dimension in the
            order of ``systems``. An empty ``systems`` list yields a tensor with
            zero rows.
        :raises NotImplementedError: For systems with mixed periodic and
            non-periodic boundary conditions.
        """
        charges = self.charges_map(features)

        if len(systems) == 0:
            # ``torch.concatenate`` rejects an empty list; return an empty
            # result with the same trailing dimension as ``charges`` instead.
            return torch.zeros_like(charges[:0])

        last_len_nodes = 0
        last_len_edges = 0
        long_range_features: List[torch.Tensor] = []
        for system in systems:
            n_atoms = len(system)
            system_charges = charges[last_len_nodes : last_len_nodes + n_atoms]
            last_len_nodes += n_atoms

            neighbor_list = system.get_neighbor_list(self.neighbor_list_options)
            neighbor_indices_system = neighbor_list.samples.view(
                ["first_atom", "second_atom"]
            ).values
            n_pairs = len(neighbor_indices_system)
            neighbor_distances_system = neighbor_distances[
                last_len_edges : last_len_edges + n_pairs
            ]
            # Always advance the edge offset (even for non-periodic systems,
            # which ignore this slice) so later systems stay aligned.
            last_len_edges += n_pairs

            if system.pbc.any() and not system.pbc.all():
                raise NotImplementedError(
                    "Long-range features are not currently supported for systems "
                    "with mixed periodic and non-periodic boundary conditions."
                )

            if system.pbc.all():
                # periodic
                if self.use_ewald and self.training:  # use Ewald for training only
                    potential = self.ewald_calculator.forward(
                        charges=system_charges,
                        cell=system.cell,
                        positions=system.positions,
                        neighbor_indices=neighbor_indices_system,
                        neighbor_distances=neighbor_distances_system,
                    )
                else:
                    potential = self.p3m_calculator.forward(
                        charges=system_charges,
                        cell=system.cell,
                        positions=system.positions,
                        neighbor_indices=neighbor_indices_system,
                        neighbor_distances=neighbor_distances_system,
                    )
            else:
                # non-periodic: direct sum over every unordered pair of atoms
                pair_indices = torch.combinations(
                    torch.arange(n_atoms, device=system.positions.device), 2
                )
                pair_distances = torch.sqrt(
                    torch.sum(
                        (
                            system.positions[pair_indices[:, 1]]
                            - system.positions[pair_indices[:, 0]]
                        )
                        ** 2,
                        dim=1,
                    )
                )
                potential = self.direct_calculator.forward(
                    charges=system_charges,
                    cell=system.cell,
                    positions=system.positions,
                    neighbor_indices=pair_indices,
                    neighbor_distances=pair_distances,
                )

            long_range_features.append(potential * system_charges)

        return torch.concatenate(long_range_features)
class DummyLongRangeFeaturizer(torch.nn.Module):
    """No-op stand-in exposing the same interface as ``LongRangeFeaturizer``.

    It carries a ``use_ewald`` attribute and a ``forward`` with the same
    signature, but performs no computation.

    NOTE(review): the original comment says this is "a dummy class for
    torchscript"; presumably it is swapped in when long-range features are
    not used, so scripted models keep a consistent attribute type. Confirm at
    the call sites.
    """

    def __init__(self):
        super(DummyLongRangeFeaturizer, self).__init__()
        # Same attribute name as on the real featurizer; always True here.
        self.use_ewald = True

    def forward(
        self,
        systems: List[System],
        features: torch.Tensor,
        neighbor_distances: torch.Tensor,
    ) -> torch.Tensor:
        """Ignore every input and return a placeholder tensor.

        :param systems: Unused.
        :param features: Unused.
        :param neighbor_distances: Unused.
        :return: ``torch.tensor(0)``, a 0-dimensional integer tensor.
        """
        placeholder = torch.tensor(0)
        return placeholder