Source code for metatrain.utils.data.writers.metatensor
from pathlib import Path
from typing import Dict, List

import torch
from metatensor.torch import TensorMap, save
from metatomic.torch import ModelCapabilities, System


# note that, although we don't use `systems` and `capabilities`, we need them to
# match the writer interface
def write_mts(
    filename: str,
    systems: List[System],
    capabilities: ModelCapabilities,
    predictions: Dict[str, TensorMap],
) -> None:
    """A metatensor-format prediction writer. Writes the predictions to `.mts` files.

    Each prediction is saved to its own file, named
    `<stem of filename>_<prediction name>.mts`.

    :param filename: name of the file to save to; only its stem is used to build
        the output file names.
    :param systems: structures the predictions refer to (not written by this writer).
    :param capabilities: capabilities of the model (not used by this writer).
    :param predictions: prediction values to be written to the files.
    """

    filename_base = Path(filename).stem
    for prediction_name, prediction_tmap in predictions.items():
        # move to CPU and cast to float64 before saving
        save(
            filename_base + "_" + prediction_name + ".mts",
            prediction_tmap.to("cpu").to(torch.float64),
        )
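The naming scheme used by `write_mts` can be illustrated without metatensor installed. The sketch below is a hypothetical helper, `mts_output_paths`, that is not part of metatrain. It only reproduces the `Path(filename).stem + "_" + prediction_name + ".mts"` expression from the writer. Because `Path.stem` discards both the directory and the extension, the resulting paths are relative, so `save` writes them relative to the current working directory. The prediction names used here are illustrative.

```python
from pathlib import Path
from typing import Dict, List


def mts_output_paths(filename: str, prediction_names: List[str]) -> Dict[str, str]:
    """Hypothetical helper (not part of metatrain): map each prediction name to
    the file name that write_mts would pass to ``save``."""
    # Path.stem drops the parent directory and the last suffix
    filename_base = Path(filename).stem
    return {name: filename_base + "_" + name + ".mts" for name in prediction_names}


# illustrative prediction names; real keys come from the model outputs
print(mts_output_paths("results/output.xyz", ["energy", "dipole"]))
```

For example, asking for `results/output.xyz` produces `output_energy.mts` and `output_dipole.mts`, not files inside `results/`.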