from typing import List
import ase.neighborlist
import numpy as np
import torch
import vesin
from metatensor.torch import Labels, TensorBlock
from metatomic.torch import (
NeighborListOptions,
System,
)
from .data.system_to_ase import system_to_ase
[docs]
def get_requested_neighbor_lists(
module: torch.nn.Module,
) -> List[NeighborListOptions]:
"""Get the neighbor lists requested by a module and its children.
:param module: The module for which to get the requested neighbor lists.
:return: A list of `NeighborListOptions` objects requested by the module.
"""
requested: List[NeighborListOptions] = []
_get_requested_neighbor_lists_in_place(
module=module,
module_name="",
requested=requested,
)
return requested
def _get_requested_neighbor_lists_in_place(
module: torch.nn.Module,
module_name: str,
requested: List[NeighborListOptions],
):
# copied from
# metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py
# and just removed the length units
if hasattr(module, "requested_neighbor_lists"):
for new_options in module.requested_neighbor_lists():
new_options.add_requestor(module_name)
already_requested = False
for existing in requested:
if existing == new_options:
already_requested = True
for requestor in new_options.requestors():
existing.add_requestor(requestor)
if not already_requested:
requested.append(new_options)
for child_name, child in module.named_children():
_get_requested_neighbor_lists_in_place(
module=child,
module_name=module_name + "." + child_name,
requested=requested,
)
[docs]
def get_system_with_neighbor_lists(
system: System, neighbor_lists: List[NeighborListOptions]
) -> System:
"""Attaches neighbor lists to a `System` object.
:param system: The system for which to calculate neighbor lists.
:param neighbor_lists: A list of `NeighborListOptions` objects,
each of which specifies the parameters for a neighbor list.
:return: The `System` object with the neighbor lists added.
"""
# Convert the system to an ASE atoms object
atoms = system_to_ase(system)
# Compute the neighbor lists
for options in neighbor_lists:
if options not in system.known_neighbor_lists():
neighbor_list = _compute_single_neighbor_list(atoms, options).to(
device=system.device, dtype=system.dtype
)
system.add_neighbor_list(options, neighbor_list)
return system
def _compute_single_neighbor_list(
atoms: ase.Atoms, options: NeighborListOptions
) -> TensorBlock:
# Computes a single neighbor list for an ASE atoms object (as in metatomic.torch)
if np.all(atoms.pbc) or np.all(~atoms.pbc):
nl_i, nl_j, nl_S, nl_D = vesin.ase_neighbor_list(
"ijSD",
atoms,
cutoff=options.cutoff,
)
else:
# this is not implemented in vesin, so we use ASE
nl_i, nl_j, nl_S, nl_D = ase.neighborlist.neighbor_list(
"ijSD",
atoms,
cutoff=options.cutoff,
)
# The pair selection code here below avoids a relatively slow loop over
# all pairs to improve performance
reject_condition = (
# we want a half neighbor list, so drop all duplicated neighbors
(nl_j < nl_i)
| (
(nl_i == nl_j)
& (
# only create pairs with the same atom twice if the pair spans more
# than one unit cell
((nl_S[:, 0] == 0) & (nl_S[:, 1] == 0) & (nl_S[:, 2] == 0))
|
# When creating pairs between an atom and one of its periodic images,
# the code generates multiple redundant pairs
# (e.g. with shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
# these. We keep the pair in the positive half plane of shifts.
(
(nl_S.sum(axis=1) < 0)
| (
(nl_S.sum(axis=1) == 0)
& ((nl_S[:, 2] < 0) | ((nl_S[:, 2] == 0) & (nl_S[:, 1] < 0)))
)
)
)
)
)
selected = np.logical_not(reject_condition)
n_pairs = np.sum(selected)
if options.full_list:
distances = np.empty((2 * n_pairs, 3), dtype=np.float64)
samples = np.empty((2 * n_pairs, 5), dtype=np.int32)
else:
distances = np.empty((n_pairs, 3), dtype=np.float64)
samples = np.empty((n_pairs, 5), dtype=np.int32)
samples[:n_pairs, 0] = nl_i[selected]
samples[:n_pairs, 1] = nl_j[selected]
samples[:n_pairs, 2:] = nl_S[selected]
distances[:n_pairs] = nl_D[selected]
if options.full_list:
samples[n_pairs:, 0] = nl_j[selected]
samples[n_pairs:, 1] = nl_i[selected]
samples[n_pairs:, 2:] = -nl_S[selected]
distances[n_pairs:] = -nl_D[selected]
distances = torch.from_numpy(distances)
return TensorBlock(
values=distances.reshape(-1, 3, 1),
samples=Labels(
names=[
"first_atom",
"second_atom",
"cell_shift_a",
"cell_shift_b",
"cell_shift_c",
],
values=torch.from_numpy(samples),
),
components=[Labels.range("xyz", 3)],
properties=Labels.range("distance", 1),
)