Source code for metatrain.utils.external_naming

from typing import Dict, Union

from metatomic.torch import ModelOutput


def to_external_name(
    internal_name: str, quantities: Union[Dict[str, ModelOutput]]
) -> str:
    """Converts internal names to external names.

    Very often, the "common" names for quantities are different from the internal
    names used in the code. Two important examples are forces and virials, which
    are referred to as ``energy_positions_gradients`` and
    ``energy_strain_gradients``, respectively, in the code. This function converts
    an internal name to an external name:

    - ``energy_positions_gradients`` -> ``forces``
    - ``<name>_positions_gradients`` -> ``forces[<name>]``
    - ``energy_strain_gradients`` -> ``virial``
    - ``<name>_strain_gradients`` -> ``virial[<name>]``

    The gradient renamings only apply when ``quantities[<name>].quantity`` is
    ``"energy"``. Any other name is returned unchanged.

    :param internal_name: An internal name to convert.
    :param quantities: A dictionary of physical quantities, either as
        :py:class:`TargetInfo` objects or as :py:class:`ModelOutput` objects.
        Only the ``quantity`` attribute of the looked-up entry is read.
    :return: The name for external use.
    :raises KeyError: If ``internal_name`` ends in a gradient suffix but its base
        name is not a key of ``quantities``.
    """
    # NOTE(review): the annotation is a single-member ``Union``, i.e. just
    # ``Dict[str, ModelOutput]``, although the docstring also accepts TargetInfo
    # values; it is kept as-is to leave the public signature untouched.
    for suffix, external_base in (
        ("_positions_gradients", "forces"),
        ("_strain_gradients", "virial"),
    ):
        if not internal_name.endswith(suffix):
            continue
        # Strip only the trailing suffix: ``str.replace`` would also delete any
        # earlier occurrence of the suffix text inside the base name itself.
        base_name = internal_name[: -len(suffix)]
        if quantities[base_name].quantity != "energy":
            # Gradients of non-energy quantities have no common external name.
            return internal_name
        if base_name == "energy":
            # "energy" is special-cased to the bare external name.
            return external_base
        return f"{external_base}[{base_name}]"
    return internal_name
def to_internal_name(external_name: str) -> str:
    """Converts an external name to an internal name.

    This function is the inverse of :func:`to_external_name` for the names that
    function produces:

    - ``forces`` -> ``energy_positions_gradients``
    - ``forces[<name>]`` -> ``<name>_positions_gradients``
    - ``virial`` -> ``energy_strain_gradients``
    - ``virial[<name>]`` -> ``<name>_strain_gradients``

    Any other name is returned unchanged.

    :param external_name: The external name to convert.
    :return: The name for internal use.
    """
    if external_name == "forces":
        return "energy_positions_gradients"
    if external_name == "virial":
        return "energy_strain_gradients"
    for prefix, suffix in (
        ("forces[", "_positions_gradients"),
        ("virial[", "_strain_gradients"),
    ):
        if external_name.startswith(prefix) and external_name.endswith("]"):
            # Slice by the prefix length rather than a hard-coded index so the
            # prefix strings and the slicing cannot get out of sync.
            base_name = external_name[len(prefix) : -1]
            return f"{base_name}{suffix}"
    return external_name