Source code for metatrain.utils.data.readers.readers
import importlib
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from metatensor.torch import TensorMap
from metatomic.torch import System
from omegaconf import DictConfig

from ..target_info import TargetInfo


# Reader names are used as submodule names: the functions in this file import
# ``metatrain.utils.data.readers.<reader>`` dynamically via ``importlib``.
AVAILABLE_READERS = ["ase", "metatensor"]
""":py:class:`list`: list containing all implemented reader libraries"""

DEFAULT_READER = {".xyz": "ase", ".extxyz": "ase", ".mts": "metatensor"}
""":py:class:`dict`: dictionary mapping file extensions to a default reader"""
def read_systems(
    filename: str,
    reader: Optional[str] = None,
) -> List[System]:
    """Read system information from a file.

    :param filename: name of the file to read
    :param reader: reader library for parsing the file. If :py:obj:`None`, the
        library is determined from the file extension via ``DEFAULT_READER``.
    :returns: list of systems read from ``filename``; all of them are stored in
        double precision
    :raises ValueError: if ``reader`` is :py:obj:`None` and the file extension is
        not linked to a default reader, if the reader module cannot be imported,
        if the reader module has no ``read_systems`` function, or if any of the
        loaded systems is not in double precision.
    """
    if reader is None:
        file_suffix = Path(filename).suffix
        try:
            reader = DEFAULT_READER[file_suffix]
        except KeyError as err:
            raise ValueError(
                f"File extension {file_suffix!r} is not linked to a default reader "
                "library. You can try reading it by setting a specific 'reader' from "
                f"the known ones: {', '.join(AVAILABLE_READERS)} "
            ) from err

    # Each reader lives in its own submodule of this package and is imported lazily
    # by name.
    try:
        reader_mod = importlib.import_module(
            name=f".{reader}", package="metatrain.utils.data.readers"
        )
    except ImportError as err:
        raise ValueError(
            f"Reader library {reader!r} not supported. Choose from "
            f"{', '.join(AVAILABLE_READERS)}"
        ) from err

    try:
        reader_met = reader_mod.read_systems
    except AttributeError as err:
        raise ValueError(
            f"Reader library {reader!r} cannot read systems. "
            f"You can try with other readers: {AVAILABLE_READERS}"
        ) from err

    systems = reader_met(filename)

    # elements in data are `torch.ScriptObject`s and their `dtype` is an integer.
    # A C++ double/torch.float64 is `7` according to
    # https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/c10/core/ScalarType.h#L54-L93
    float64_code = 7
    if not all(s.dtype == float64_code for s in systems):
        raise ValueError("The loaded systems are not in double precision.")

    return systems
def read_targets(
    conf: DictConfig,
) -> Tuple[Dict[str, List[TensorMap]], Dict[str, TargetInfo]]:
    """Read all target information from a fully expanded config.

    To get such a config you can use :func:`expand_dataset_config
    <metatrain.utils.omegaconf.expand_dataset_config>`. All targets are stored in
    double precision.

    For every target section, the reader module is selected from the section's
    ``reader`` entry (or, if that is :py:obj:`None`, from the extension of its
    ``read_from`` file). The reader's ``read_energy`` function is used when the
    section describes a scalar ``energy`` quantity that is not per-atom and has a
    single subtarget; otherwise the reader's ``read_generic`` function is used.

    A warning is emitted when a top-level target name looks like an energy gradient
    (it contains ``force``, ``virial`` or ``stress``), since such quantities are
    expected to be sub-sections of an energy target.

    :param conf: config containing the keys for what should be read.
    :returns: Dictionary containing a list of TensorMaps for each target section in
        the config as well as a ``Dict[str, TargetInfo]`` object containing the
        metadata of the targets.
    :raises ValueError: if the target name is not valid. Valid target names are
        those that either start with ``mtt::`` or those that are in the list of
        standard outputs of ``metatomic`` (see
        https://docs.metatensor.org/metatomic/latest/outputs/)
    :raises ValueError: if no reader can be determined or imported for a target, if
        the reader cannot read the requested kind of target, or if the loaded
        targets are not in double precision.
    """
    target_dictionary = {}
    target_info_dictionary = {}
    standard_outputs_list = [
        "energy",
        "non_conservative_forces",
        "non_conservative_stress",
    ]
    gradient_like_fragments = ("force", "virial", "stress")

    # elements in data are `torch.ScriptObject`s and their `dtype` is an integer.
    # A C++ double/torch.float64 is `7` according to
    # https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/c10/core/ScalarType.h#L54-L93
    float64_code = 7

    for target_key, target in conf.items():
        key_lower = target_key.lower()

        is_standard_target = target_key in standard_outputs_list
        if not is_standard_target and not target_key.startswith("mtt::"):
            # Bare gradient names are tolerated with a warning; any other
            # non-standard, non-`mtt::` name is rejected.
            if key_lower in ["force", "forces", "virial", "stress"]:
                warnings.warn(
                    f"{target_key!r} should not be its own top-level target, "
                    "but rather a sub-section of the 'energy' target",
                    stacklevel=2,
                )
            else:
                raise ValueError(
                    f"Target name ({target_key}) must either be one of "
                    f"{standard_outputs_list} or start with `mtt::`."
                )

        # NOTE(review): reconstructed at loop level, so this check also applies to
        # `mtt::`-prefixed and standard names -- confirm against the original file.
        if any(fragment in key_lower for fragment in gradient_like_fragments):
            warnings.warn(
                f"the name of {target_key!r} resembles to a gradient of "
                "energies; it should probably not be its own top-level target, "
                "but rather a gradient sub-section of a target with the "
                "`energy` quantity",
                stacklevel=2,
            )

        is_energy = (
            (target["quantity"] == "energy")
            and (not target["per_atom"])
            and target["num_subtargets"] == 1
            and target["type"] == "scalar"
        )
        energy_or_generic = "energy" if is_energy else "generic"

        reader = target["reader"]
        filename = target["read_from"]

        if reader is None:
            file_suffix = Path(filename).suffix
            try:
                reader = DEFAULT_READER[file_suffix]
            except KeyError as err:
                raise ValueError(
                    f"File extension {file_suffix!r} is not linked to a default "
                    "reader "
                    "library. You can try reading it by setting a specific 'reader' "
                    f"from the known ones: {', '.join(AVAILABLE_READERS)} "
                ) from err

        # Each reader lives in its own submodule of this package and is imported
        # lazily by name.
        try:
            reader_mod = importlib.import_module(
                name=f".{reader}", package="metatrain.utils.data.readers"
            )
        except ImportError as err:
            raise ValueError(
                f"Reader library {reader!r} not supported. Choose from "
                f"{', '.join(AVAILABLE_READERS)}"
            ) from err

        try:
            reader_met = getattr(reader_mod, f"read_{energy_or_generic}")
        except AttributeError as err:
            raise ValueError(
                f"Reader library {reader!r} cannot read {target!r}. "
                f"You can try with other readers: {AVAILABLE_READERS}"
            ) from err

        targets_as_list_of_tensor_maps, target_info = reader_met(target)

        if not all(t.dtype == float64_code for t in targets_as_list_of_tensor_maps):
            raise ValueError("The loaded targets are not in double precision.")

        target_dictionary[target_key] = targets_as_list_of_tensor_maps
        target_info_dictionary[target_key] = target_info

    return target_dictionary, target_info_dictionary