from typing import List
import torch
from metatomic.torch import System
[docs]
class LongRangeFeaturizer(torch.nn.Module):
"""A class to compute long-range features starting from short-range features.
:param hypers: Dictionary containing the hyperparameters for the long-range
featurizer.
:param feature_dim: The dimension of the short-range features (which also
corresponds to the number of long-range features that will be returned).
:param neighbor_list_options: A :py:class:`NeighborListOptions` object containing
the neighbor list information for the short-range model.
"""
def __init__(self, hypers, feature_dim, neighbor_list_options):
super(LongRangeFeaturizer, self).__init__()
try:
from torchpme import CoulombPotential
from torchpme.calculators import Calculator, EwaldCalculator, P3MCalculator
except ImportError:
raise ImportError(
"`torch-pme` is required for long-range models. "
"Please install it with `pip install torch-pme`."
)
self.ewald_calculator = EwaldCalculator(
potential=CoulombPotential(
smearing=float(hypers["smearing"]),
exclusion_radius=neighbor_list_options.cutoff,
),
full_neighbor_list=neighbor_list_options.full_list,
lr_wavelength=float(hypers["kspace_resolution"]),
)
self.p3m_calculator = P3MCalculator(
potential=CoulombPotential(
smearing=float(hypers["smearing"]),
exclusion_radius=neighbor_list_options.cutoff,
),
interpolation_nodes=hypers["interpolation_nodes"],
full_neighbor_list=neighbor_list_options.full_list,
mesh_spacing=float(hypers["kspace_resolution"]),
)
self.use_ewald = hypers["use_ewald"]
self.direct_calculator = Calculator(
potential=CoulombPotential(
smearing=None,
exclusion_radius=neighbor_list_options.cutoff,
),
full_neighbor_list=False, # see docs of torch.combinations
)
self.neighbor_list_options = neighbor_list_options
self.charges_map = torch.nn.Linear(feature_dim, feature_dim)
[docs]
def forward(
self,
systems: List[System],
features: torch.Tensor,
neighbor_distances: torch.Tensor,
) -> torch.Tensor:
"""Compute the long-range features for a list of systems.
:param systems: A list of :py:class:`System` objects for which to compute the
long-range features. Each system must contain a neighbor list consistent
with the neighbor list options used to create the class.
:param features: A tensor of short-range features for the systems.
:param neighbor_distances: A tensor of neighbor distances for the systems,
which must be consistent with the neighbor list options used to create the
class.
"""
charges = self.charges_map(features)
last_len_nodes = 0
last_len_edges = 0
long_range_features = []
for system in systems:
system_charges = charges[last_len_nodes : last_len_nodes + len(system)]
last_len_nodes += len(system)
neighbor_list = system.get_neighbor_list(self.neighbor_list_options)
neighbor_indices_system = neighbor_list.samples.view(
["first_atom", "second_atom"]
).values
neighbor_distances_system = neighbor_distances[
last_len_edges : last_len_edges + len(neighbor_indices_system)
]
last_len_edges += len(neighbor_indices_system)
if system.pbc.any() and not system.pbc.all():
raise NotImplementedError(
"Long-range features are not currently supported for systems "
"with mixed periodic and non-periodic boundary conditions."
)
if system.pbc.all(): # periodic
if self.use_ewald and self.training: # use Ewald for training only
potential = self.ewald_calculator.forward(
charges=system_charges,
cell=system.cell,
positions=system.positions,
neighbor_indices=neighbor_indices_system,
neighbor_distances=neighbor_distances_system,
)
else:
potential = self.p3m_calculator.forward(
charges=system_charges,
cell=system.cell,
positions=system.positions,
neighbor_indices=neighbor_indices_system,
neighbor_distances=neighbor_distances_system,
)
else: # non-periodic
# compute the distance between all pairs of atoms
neighbor_indices_system = torch.combinations(
torch.arange(len(system), device=system.positions.device), 2
)
neighbor_distances_system = torch.sqrt(
torch.sum(
(
system.positions[neighbor_indices_system[:, 1]]
- system.positions[neighbor_indices_system[:, 0]]
)
** 2,
dim=1,
)
)
potential = self.direct_calculator.forward(
charges=system_charges,
cell=system.cell,
positions=system.positions,
neighbor_indices=neighbor_indices_system,
neighbor_distances=neighbor_distances_system,
)
long_range_features.append(potential * system_charges)
return torch.concatenate(long_range_features)
[docs]
class DummyLongRangeFeaturizer(torch.nn.Module):
# a dummy class for torchscript
def __init__(self):
super().__init__()
self.use_ewald = True
[docs]
def forward(
self,
systems: List[System],
features: torch.Tensor,
neighbor_distances: torch.Tensor,
) -> torch.Tensor:
return torch.tensor(0)