Source code for metatrain.utils.abc

from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import (
    MetatensorAtomisticModel,
    ModelMetadata,
    ModelOutput,
    System,
)

from metatrain.utils.data.dataset import Dataset, DatasetInfo


class ModelInterface(torch.nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for a machine learning model in metatrain.

    All architectures in metatrain must be implemented as sub-classes of this
    class, and must implement the corresponding methods.
    """

    def __init__(self):
        super().__init__()

        required_attributes = [
            "__supported_devices__",
            "__supported_dtypes__",
            "__default_metadata__",
        ]
        for attribute in required_attributes:
            if not hasattr(self.__class__, attribute):
                raise TypeError(
                    f"missing '{attribute}' class attribute for "
                    f"{self.__class__.__name__}"
                )

    @abstractmethod
    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """
        Execute the model for the given ``systems``, computing the requested
        ``outputs``.

        .. seealso::

            :py:class:`metatensor.torch.atomistic.ModelInterface` for more
            explanation about the different arguments.
        """

    @abstractmethod
    def supported_outputs(self) -> Dict[str, ModelOutput]:
        """
        Get the outputs currently supported by this model.

        These will likely be the same outputs that are set as this model's
        capabilities in :py:func:`ModelInterface.export`.
        """

    @abstractmethod
    def restart(self, dataset_info: DatasetInfo) -> "ModelInterface":
        """
        Update a model to restart training, potentially with different
        datasets and/or targets.

        This function is called whenever training restarts, with the same or a
        different dataset. It enables transfer learning (changing the targets)
        and fine-tuning (same targets, different datasets).

        This function should return the updated model, or a new instance of
        the model able to handle the new dataset.
        """

    @classmethod
    @abstractmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        context: Literal["restart", "finetune", "export"],
    ) -> "ModelInterface":
        """
        Create a model from a checkpoint (i.e. a state dictionary).

        :param checkpoint: Checkpoint's state dictionary.
        :param context: Context in which to load the model. Possible values
            are ``"restart"`` when restarting a stopped training run,
            ``"finetune"`` when loading a model for further fine-tuning or
            transfer learning, and ``"export"`` when loading a model for final
            export. When multiple checkpoints are stored together, this can be
            used to pick one of them depending on the context.
        """

    @abstractmethod
    def export(
        self,
        metadata: Optional[ModelMetadata] = None,
    ) -> MetatensorAtomisticModel:
        """
        Turn this model into an instance of
        :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`,
        containing the model itself, a definition of the model capabilities,
        and some metadata about the model.

        :param metadata: additional metadata to add to the model, as specified
            by the user.
        """
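

# What follows is not part of metatrain.utils.abc: it is a minimal sketch of
# how an architecture could implement the interface above. The class name,
# the constant "energy" output, and the checkpoint layout ("dataset_info",
# "model_state_dict") are illustrative assumptions, not metatrain conventions.
class SketchModel(ModelInterface):
    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata()

    def __init__(self, dataset_info: DatasetInfo):
        super().__init__()
        self.dataset_info = dataset_info

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        results: Dict[str, TensorMap] = {}
        for name in outputs:
            # a real architecture would assemble per-system predictions into
            # a TensorMap here
            ...
        return results

    def supported_outputs(self) -> Dict[str, ModelOutput]:
        return {"energy": ModelOutput(quantity="energy", unit="eV", per_atom=False)}

    def restart(self, dataset_info: DatasetInfo) -> "SketchModel":
        # a real architecture would merge the new targets into the existing
        # ones to enable transfer learning
        self.dataset_info = dataset_info
        return self

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        context: Literal["restart", "finetune", "export"],
    ) -> "SketchModel":
        model = cls(checkpoint["dataset_info"])
        model.load_state_dict(checkpoint["model_state_dict"])
        return model

    def export(
        self, metadata: Optional[ModelMetadata] = None
    ) -> MetatensorAtomisticModel:
        # a real architecture would declare ModelCapabilities matching
        # supported_outputs() and wrap itself in a MetatensorAtomisticModel
        raise NotImplementedError("left out of this sketch")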


class TrainerInterface(metaclass=ABCMeta):
    """
    Abstract base class for a model trainer in metatrain.

    All architectures in metatrain must implement such a trainer, which is
    responsible for training the model. The trainer must be a sub-class of
    this class, and implement the corresponding methods.
    """

    @abstractmethod
    def __init__(self, train_hypers):
        """
        Create a trainer using the hyper-parameters in ``train_hypers``.
        """

    @abstractmethod
    def train(
        self,
        model: ModelInterface,
        dtype: torch.dtype,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        checkpoint_dir: str,
    ):
        """
        Train the ``model`` using the ``train_datasets``. How to train the
        model is left to this class, using the hyper-parameters given in
        ``__init__``.

        :param model: the model to train
        :param dtype: ``torch.dtype`` used by the data in the datasets
        :param devices: ``torch.device`` to use for training the model. When
            training with more than one device (e.g. multi-GPU training), this
            can contain multiple devices.
        :param train_datasets: datasets to use to train the model
        :param val_datasets: datasets to use for model validation
        :param checkpoint_dir: directory where checkpoints should be saved
        """

    @abstractmethod
    def save_checkpoint(self, model, path: Union[str, Path]):
        """
        Save a checkpoint of both the ``model`` and trainer state to the given
        ``path``.
        """

    @classmethod
    @abstractmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        train_hypers: Dict[str, Any],
        context: Literal["restart", "finetune"],
    ) -> "TrainerInterface":
        """
        Create a trainer instance from data stored in the ``checkpoint``.

        :param checkpoint: Checkpoint's state dictionary.
        :param train_hypers: Hyper-parameters for the trainer, as specified by
            the user.
        :param context: Context in which to load the trainer. Possible values
            are ``"restart"`` when restarting a stopped training run, and
            ``"finetune"`` when loading a model for further fine-tuning or
            transfer learning. When multiple checkpoints are stored together,
            this can be used to pick one of them depending on the context.
        """
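

# Again a sketch, not part of metatrain: a trainer matching the SketchModel
# above. The hyper-parameter names ("learning_rate", "num_epochs") and the
# checkpoint keys are assumptions for illustration only.
class SketchTrainer(TrainerInterface):
    def __init__(self, train_hypers):
        self.hypers = train_hypers
        self.epoch = 0

    def train(
        self,
        model: ModelInterface,
        dtype: torch.dtype,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        checkpoint_dir: str,
    ):
        model = model.to(device=devices[0], dtype=dtype)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=self.hypers.get("learning_rate", 1e-3)
        )
        for epoch in range(self.hypers.get("num_epochs", 1)):
            self.epoch = epoch
            # a real trainer would loop over the training datasets, compute
            # the loss, step the optimizer, and periodically evaluate on the
            # validation datasets and save checkpoints in checkpoint_dir
            ...

    def save_checkpoint(self, model, path: Union[str, Path]):
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "dataset_info": model.dataset_info,
                "epoch": self.epoch,
            },
            path,
        )

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        train_hypers: Dict[str, Any],
        context: Literal["restart", "finetune"],
    ) -> "SketchTrainer":
        trainer = cls(train_hypers)
        trainer.epoch = checkpoint.get("epoch", 0)
        return trainer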