Source code for metatrain.utils.data.readers.ase

import logging
import warnings
from pathlib import PurePath
from typing import IO, List, Tuple, Union

import ase.io
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import System, systems_to_torch
from omegaconf import DictConfig

from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info


def read(filename: Union[str, PurePath, IO], *args, **kwargs) -> List[ase.Atoms]:
    """Wrapper around the :func:`ase.io.read` function.

    The wrapper provides a more informative error message in case of failure.
    Additionally, it will make the keys ``"energy"``, ``"forces"`` and ``"stress"``
    available from the calculator and the info/arrays dictionary.

    .. warning::

        Lists of atoms read with this function can NOT be written back to a file
        with :func:`ase.io.write` because of the duplicated keys.

    .. note::

        The object produced by :func:`ase.io.read` is returned unchanged. If the
        ``index`` argument selects a single frame, ASE produces a single
        :class:`ase.Atoms` instead of a list; the special keys are still made
        available on that object.

    :param filename: Name of the file to read from or a file descriptor.
    :param args: additional positional arguments for :func:`ase.io.read`
    :param kwargs: additional keyword arguments for :func:`ase.io.read`
    :returns: A list of atoms
    :raises ValueError: if ASE fails to read ``filename`` for any reason; the
        original exception is chained as the cause.
    """
    try:
        frames = ase.io.read(filename, *args, **kwargs)
    except Exception as e:
        raise ValueError(f"Failed to read '{filename}' with ASE: {e}") from e

    # A single ``Atoms`` must not be iterated directly: that would loop over its
    # individual atoms (which carry no calculator) and silently skip the key
    # duplication below.
    selected = [frames] if isinstance(frames, ase.Atoms) else frames

    # allow access of "special" keys from calculator and `info`/`arrays` dictionary
    for atoms in selected:
        calc = getattr(atoms, "calc", None)
        if calc is None:
            continue

        results = calc.results
        if "energy" in results:
            atoms.info["energy"] = results["energy"]
        if "forces" in results:
            atoms.arrays["forces"] = results["forces"]
        if "stress" in results:
            # stored in `info` as a full 3 x 3 matrix rather than Voigt notation
            atoms.info["stress"] = voigt_6_to_full_3x3_stress(results["stress"])

    return frames
def read_systems(filename: str) -> List[System]:
    """Load every frame of a file with ASE and convert it to a system.

    All frames are requested from :func:`read` (index ``":"``) and handed to
    ``systems_to_torch`` with a ``torch.float64`` dtype.

    :param filename: name of the file to read
    :returns: A list of systems
    """
    all_frames = read(filename, ":")
    return systems_to_torch(all_frames, dtype=torch.float64)
def _read_energy_ase(filename: str, key: str) -> List[TensorBlock]:
    """Read one energy per frame into a list of :class:`metatensor.TensorBlock`.

    :param filename: name of the file to read
    :param key: name of the entry of ``atoms.info`` holding the energy
    :returns: one block per frame, each with a single ``"system"`` sample, no
        components and a single ``"energy"`` property
    :raises ValueError: if ``key`` is missing from the ``info`` of a frame
    """
    all_frames = read(filename, ":")
    energy_property = Labels("energy", torch.tensor([[0]]))

    energy_blocks = []
    for index, frame in enumerate(all_frames):
        if key not in frame.info:
            raise ValueError(
                f"energy key {key!r} was not found in system {filename!r} at index "
                f"{index}"
            )

        energy_blocks.append(
            TensorBlock(
                values=torch.tensor([[frame.info[key]]], dtype=torch.float64),
                samples=Labels(["system"], torch.tensor([[index]])),
                components=[],
                properties=energy_property,
            )
        )

    return energy_blocks


def _read_forces_ase(filename: str, key: str = "forces") -> List[TensorBlock]:
    """Read per-atom forces into a list of :class:`metatensor.TensorBlock`.

    The blocks are laid out so that they can be attached as ``positions``
    gradients of an energy block: the sign of the stored forces is flipped.

    :param filename: name of the file to read
    :param key: name of the entry of ``atoms.arrays`` holding the forces
    :returns: one block per frame, with samples ``("sample", "system", "atom")``
        and a single ``"xyz"`` component
    :raises ValueError: if ``key`` is missing from the ``arrays`` of a frame
    """
    all_frames = read(filename, ":")
    xyz = [Labels(["xyz"], torch.arange(3).reshape(-1, 1))]
    energy_property = Labels("energy", torch.tensor([[0]]))

    force_blocks = []
    for index, frame in enumerate(all_frames):
        if key not in frame.arrays:
            raise ValueError(
                f"forces key {key!r} was not found in system {filename!r} at index "
                f"{index}"
            )

        forces = torch.tensor(frame.arrays[key], dtype=torch.float64)
        # a position gradient of the energy is minus the force
        gradient = (-forces).reshape(-1, 3, 1)

        atom_rows = [[0, index, atom] for atom in range(len(gradient))]
        force_blocks.append(
            TensorBlock(
                values=gradient,
                samples=Labels(["sample", "system", "atom"], torch.tensor(atom_rows)),
                components=xyz,
                properties=energy_property,
            )
        )

    return force_blocks


def _read_virial_ase(filename: str, key: str = "virial") -> List[TensorBlock]:
    """Read virials into a list of :class:`metatensor.TensorBlock`.

    Thin front-end of :func:`_read_virial_stress_ase` with ``is_virial=True``;
    the blocks can be used as ``strain`` gradients.

    :param filename: name of the file to read
    :param key: target value key name to be parsed from the file
    :returns: one block per frame containing the sign-flipped virial
    """
    return _read_virial_stress_ase(filename=filename, key=key, is_virial=True)


def _read_stress_ase(filename: str, key: str = "stress") -> List[TensorBlock]:
    """Read stresses into a list of :class:`metatensor.TensorBlock`.

    Thin front-end of :func:`_read_virial_stress_ase` with ``is_virial=False``;
    the blocks can be used as ``strain`` gradients.

    :param filename: name of the file to read
    :param key: target value key name to be parsed from the file
    :returns: one block per frame containing the stress scaled by the cell volume
    """
    return _read_virial_stress_ase(filename=filename, key=key, is_virial=False)


def _read_virial_stress_ase(
    filename: str, key: str, is_virial: bool = True
) -> List[TensorBlock]:
    """Shared reader behind :func:`_read_virial_ase` and :func:`_read_stress_ase`.

    The value found under ``key`` in ``atoms.info`` must be a 3 x 3 matrix or a
    9-long vector (a warning is emitted for the latter). Virials are multiplied
    by ``-1``; stresses are multiplied by the cell volume.

    :param filename: name of the file to read
    :param key: name of the entry of ``atoms.info`` holding the tensor
    :param is_virial: treat the values as a virial (``True``) or as a stress
        (``False``)
    :returns: one block per frame with a single ``"sample"`` sample and the
        components ``"xyz_1"`` and ``"xyz_2"``
    :raises ValueError: if ``key`` is missing, if the values have an unsupported
        shape, or if a stress is requested for a frame with zero cell volume
    """
    all_frames = read(filename, ":")

    single_sample = Labels(["sample"], torch.tensor([[0]]))
    cartesian_pair = [
        Labels([name], torch.arange(3).reshape(-1, 1)) for name in ("xyz_1", "xyz_2")
    ]
    energy_property = Labels("energy", torch.tensor([[0]]))

    strain_blocks = []
    for index, frame in enumerate(all_frames):
        if key not in frame.info:
            target_name = "virial" if is_virial else "stress"
            raise ValueError(
                f"{target_name} key {key!r} was not found in system {filename!r} at "
                f"index {index}"
            )

        gradient = torch.tensor(frame.info[key].tolist(), dtype=torch.float64)

        if gradient.shape != (3, 3):
            if gradient.shape != (9,):
                raise ValueError(
                    f"Values in system {index} has shape {gradient.shape}. "
                    "Stress/virial must be a 3 x 3 matrix or a 9-long numerical vector."
                )
            warnings.warn(
                "Found 9-long numerical vector for the stress/virial in system "
                f"{index}. Assume a row major format for the conversion into a "
                "3 x 3 matrix.",
                stacklevel=2,
            )

        gradient = gradient.reshape(-1, 3, 3, 1)

        if is_virial:
            scale = -1
        else:
            scale = frame.cell.volume
            if scale == 0:
                raise ValueError(
                    f"system {index} has zero cell vectors. Stress can only "
                    "be used if cell is non zero."
                )
        gradient *= scale

        strain_blocks.append(
            TensorBlock(
                values=gradient,
                samples=single_sample,
                components=cartesian_pair,
                properties=energy_property,
            )
        )

    return strain_blocks
def _attach_gradients(
    blocks: List[TensorBlock],
    reader,
    section: DictConfig,
    parameter: str,
    label: str,
    target_key: str,
) -> bool:
    """Try to read gradients with ``reader`` and attach them to ``blocks``.

    :param blocks: energy blocks, one per system, modified in place
    :param reader: callable accepting ``filename`` and ``key`` keyword arguments
        and returning one gradient block per system
    :param section: configuration section providing ``"read_from"`` and ``"key"``
    :param parameter: gradient parameter name passed to ``add_gradient``
    :param label: lower-case name of the quantity, used in log and error messages
    :param target_key: key of the energy target, used in log and error messages
    :returns: ``True`` if the gradients were attached, ``False`` if they could
        not be read
    :raises ValueError: if the number of gradient blocks differs from the number
        of energy blocks
    """
    try:
        gradients = reader(filename=section["read_from"], key=section["key"])
    except Exception:
        # Deliberate best effort: a gradient that cannot be read is skipped.
        logging.warning(f"No {label} found in section {target_key!r}.")
        return False

    # ``zip`` would silently stop at the shorter list and leave some energy
    # blocks without gradients, so a frame-count mismatch is reported instead.
    if len(gradients) != len(blocks):
        raise ValueError(
            f"Found {len(gradients)} {label} entries but {len(blocks)} energies in "
            f"section {target_key!r}. Both must contain the same number of systems."
        )

    logging.info(
        f"{label.capitalize()} found in section {target_key!r}, "
        "we will use this gradient to train the model"
    )
    for block, gradient in zip(blocks, gradients):
        block.add_gradient(parameter=parameter, gradient=gradient)

    return True


def read_energy(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read an energy target and its optional gradients with ASE.

    Energies are read from ``target["read_from"]`` under ``target["key"]``. When
    the ``"forces"``, ``"stress"`` or ``"virial"`` sub-sections are set, the
    corresponding values are read and attached as ``positions`` (forces) or
    ``strain`` (stress, virial) gradients. Gradients that cannot be read are
    skipped with a warning.

    :param target: configuration of the energy target
    :returns: one single-block :class:`TensorMap` per system, and the target
        information returned by ``get_energy_target_info``
    :raises ValueError: if both ``"stress"`` and ``"virial"`` are requested, or if
        a gradient file does not contain as many systems as the energy file
    """
    target_key = target["key"]

    blocks = _read_energy_ase(
        filename=target["read_from"],
        key=target["key"],
    )

    add_position_gradients = False
    if target["forces"]:
        add_position_gradients = _attach_gradients(
            blocks=blocks,
            reader=_read_forces_ase,
            section=target["forces"],
            parameter="positions",
            label="forces",
            target_key=target_key,
        )

    if target["stress"] and target["virial"]:
        raise ValueError("Cannot use stress and virial at the same time")

    add_strain_gradients = False
    if target["stress"]:
        add_strain_gradients = _attach_gradients(
            blocks=blocks,
            reader=_read_stress_ase,
            section=target["stress"],
            parameter="strain",
            label="stress",
            target_key=target_key,
        )

    if target["virial"]:
        add_strain_gradients = _attach_gradients(
            blocks=blocks,
            reader=_read_virial_ase,
            section=target["virial"],
            parameter="strain",
            label="virial",
            target_key=target_key,
        )

    tensor_map_list = [
        TensorMap(
            keys=Labels(["_"], torch.tensor([[0]])),
            blocks=[block],
        )
        for block in blocks
    ]

    target_info = get_energy_target_info(
        target, add_position_gradients, add_strain_gradients
    )
    return tensor_map_list, target_info
def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read a generic target with ASE.

    Per-atom targets are taken from ``atoms.arrays`` and per-system targets from
    ``atoms.info``, under ``target["key"]``. The layout (keys, components,
    properties) comes from ``get_generic_target_info``.

    :param target: configuration of the target
    :returns: one single-block :class:`TensorMap` per system, and the target
        information
    :raises ValueError: if a spherical target declares more than one irreducible
        representation, or if the target key is missing in a frame
    """
    filename = target["read_from"]
    all_frames = read(filename, ":")

    # we don't allow ASE to read spherical tensors with more than one irrep,
    # otherwise it's a mess
    target_type = target["type"]
    if isinstance(target_type, DictConfig):
        first_type_name = next(iter(target_type.keys()))
        if first_type_name == "spherical":
            if len(target_type["spherical"]["irreps"]) > 1:
                raise ValueError(
                    "The metatrain ASE reader does not support reading "
                    "spherical tensors with more than one irreducible "
                    "representation. Please use the metatensor reader."
                )

    target_info = get_generic_target_info(target)
    layout_block = target_info.layout.block()
    components = layout_block.components
    properties = layout_block.properties
    trailing_shape = layout_block.shape[1:]
    per_atom = target_info.per_atom
    keys = target_info.layout.keys

    target_key = target["key"]

    tensor_maps = []
    for index, frame in enumerate(all_frames):
        # per-atom data lives in `arrays`, per-system data in `info`
        container = frame.arrays if per_atom else frame.info
        if target_key not in container:
            raise ValueError(
                f"Target key {target_key!r} was not found in system {filename!r} at "
                f"index {index}"
            )

        # here we reshape to allow for more flexibility; this is actually
        # necessary for the `arrays`, which are stored in a 2D array
        raw = torch.tensor(container[target_key], dtype=torch.float64)
        values = raw.reshape([-1] + trailing_shape)

        if per_atom:
            atom_rows = [[index, atom] for atom in range(len(values))]
            samples = Labels(["system", "atom"], torch.tensor(atom_rows))
        else:
            samples = Labels(["system"], torch.tensor([[index]]))

        tensor_maps.append(
            TensorMap(
                keys=keys,
                blocks=[
                    TensorBlock(
                        values=values,
                        samples=samples,
                        components=components,
                        properties=properties,
                    )
                ],
            )
        )

    return tensor_maps, target_info