Source code for metatrain.utils.scaler

from typing import Dict, List, Union

import metatensor.torch
import numpy as np
import torch
from metatensor.torch import TensorMap
from metatomic.torch import ModelOutput

from .additive import remove_additive
from .data import Dataset, DatasetInfo, TargetInfo, get_all_targets
from .jsonschema import validate
from .per_atom import average_by_num_atoms
from .transfer import systems_and_targets_to_device


class Scaler(torch.nn.Module):
    """
    A class that scales the targets of regression problems to unit standard
    deviation.

    In most cases, this should be used in conjunction with a composition model
    (that removes the multi-dimensional "mean" across the composition space)
    and/or other additive models. See the `train_model` method for more
    details.

    The scaling is performed per atom, i.e., in cases where the targets are
    per structure, the standard deviation is calculated on the targets divided
    by the number of atoms in each structure.

    A brief usage sketch is given after this class definition.

    :param model_hypers: A dictionary of model hyperparameters. This parameter
        is ignored and is only present to be consistent with the general model
        API.
    :param dataset_info: An object containing information about the dataset,
        including target quantities and atomic types.
    """

    outputs: Dict[str, ModelOutput]
    scales: torch.Tensor

    def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
        super().__init__()

        # `model_hypers` should be an empty dictionary
        validate(
            instance=model_hypers,
            schema={"type": "object", "additionalProperties": False},
        )

        self.dataset_info = dataset_info
        self.new_targets: Dict[str, TargetInfo] = dataset_info.targets
        self.outputs: Dict[str, ModelOutput] = {}

        # Initially, the scales are empty. They will be expanded as new
        # outputs are registered with `_add_output`.
        self.register_buffer("scales", torch.ones((0,), dtype=torch.float64))
        self.output_name_to_output_index: Dict[str, int] = {}
        for target_name, target_info in self.dataset_info.targets.items():
            self._add_output(target_name, target_info)
    def train_model(
        self,
        datasets: List[Union[Dataset, torch.utils.data.Subset]],
        additive_models: List[torch.nn.Module],
        treat_as_additive: bool,
    ) -> None:
        """
        Calculate the scaling weights for all the targets in the datasets.

        :param datasets: Dataset(s) to calculate the scaling weights for.
        :param additive_models: Additive models to be removed from the targets
            before calculating the statistics.
        :param treat_as_additive: If True, all per-structure targets (i.e.,
            those that do not contain an ``atom`` label name) are treated as
            additive.

        :raises ValueError: If ``treat_as_additive`` is False, or if one of
            the scaler's new targets is not present in any of the provided
            datasets.
        """
        if not treat_as_additive:
            raise ValueError(
                "The Scaler class can currently only be trained by treating "
                "targets as additive."
            )

        if not isinstance(datasets, list):
            datasets = [datasets]

        device = self.scales.device

        # Fill the scales for each "new" target (i.e. those that do not
        # already have scales from a previous training run)
        for target_key in self.new_targets:
            datasets_with_target = []
            for dataset in datasets:
                if target_key in get_all_targets(dataset):
                    datasets_with_target.append(dataset)
            if len(datasets_with_target) == 0:
                raise ValueError(
                    f"Target {target_key} in the model's new capabilities is "
                    "not present in any of the training datasets."
                )

            sum_of_squared_targets = 0.0
            total_num_elements = 0

            for dataset in datasets_with_target:
                for sample in dataset:
                    systems = [sample["system"]]
                    targets = {target_key: sample[target_key]}
                    systems, targets = systems_and_targets_to_device(
                        systems, targets, device
                    )
                    for additive_model in additive_models:
                        target_info_dict = {
                            target_key: self.new_targets[target_key]
                        }
                        targets = remove_additive(
                            systems,
                            targets,
                            additive_model,
                            target_info_dict,
                        )

                    # calculate standard deviations on per-atom quantities
                    targets = average_by_num_atoms(
                        targets,
                        systems,
                        per_structure_keys=[],
                    )

                    target_info = self.new_targets[target_key]
                    if (
                        target_info.quantity == "energy"
                        and "positions" in target_info.gradients
                    ):
                        # special case: here we want to scale with respect to
                        # the forces rather than the energies
                        sum_of_squared_targets += torch.sum(
                            targets[target_key]
                            .block()
                            .gradient("positions")
                            .values
                            ** 2
                        ).item()
                        total_num_elements += (
                            targets[target_key]
                            .block()
                            .gradient("positions")
                            .values.numel()
                        )
                    else:
                        sum_of_squared_targets += sum(
                            torch.sum(block.values**2).item()
                            for block in targets[target_key].blocks()
                        )
                        total_num_elements += sum(
                            block.values.numel()
                            for block in targets[target_key].blocks()
                        )

            self.scales[self.output_name_to_output_index[target_key]] = (
                np.sqrt(sum_of_squared_targets / total_num_elements)
            )
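    # Note on the statistic computed above (illustrative numbers, not from the
    # source): after the additive contributions are removed, the scale is the
    # root-mean-square of the per-atom target entries (or of the force
    # components, in the energy special case). For example, per-atom residuals
    # [1.0, 2.0, 2.0, 1.0] give sqrt((1 + 4 + 4 + 1) / 4) = sqrt(2.5) ~ 1.58.
    # This equals the standard deviation whenever the additive models have
    # already removed the mean.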
    def restart(self, dataset_info: DatasetInfo) -> "Scaler":
        # merge old and new dataset info
        merged_info = self.dataset_info.union(dataset_info)
        self.new_targets = {
            key: value
            for key, value in merged_info.targets.items()
            if key not in self.dataset_info.targets
        }

        # register new outputs
        for target_name, target in self.new_targets.items():
            self._add_output(target_name, target)

        self.dataset_info = merged_info

        return self
    def forward(
        self,
        outputs: Dict[str, TensorMap],
    ) -> Dict[str, TensorMap]:
        """
        Scale all the targets in the outputs dictionary back to their
        original scale.

        :param outputs: A dictionary of target quantities and their values
            to be scaled.

        Outputs that do not have a corresponding scale in the scaler model
        are returned unchanged.
        """
        scaled_outputs: Dict[str, TensorMap] = {}
        for target_key, target in outputs.items():
            if target_key in self.outputs:
                scale = float(
                    self.scales[
                        self.output_name_to_output_index[target_key]
                    ].item()
                )
                scaled_target = metatensor.torch.multiply(target, scale)
                scaled_outputs[target_key] = scaled_target
            else:
                scaled_outputs[target_key] = target

        return scaled_outputs
    def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
        self.outputs[target_name] = ModelOutput(
            quantity=target_info.quantity,
            unit=target_info.unit,
            per_atom=True,
        )
        self.scales = torch.cat(
            [self.scales, torch.tensor([1.0], dtype=self.scales.dtype)]
        )
        self.output_name_to_output_index[target_name] = len(self.scales) - 1
    def get_scales_dict(self) -> Dict[str, float]:
        """
        Return a dictionary with the scales for each output and output
        gradient.

        :return: A dictionary with the scales for each output and output
            gradient. These correspond to the standard deviation of the
            targets in the original dataset. The scale for each output
            gradient is the same as that of the corresponding output.
        """
        scales_dict = {
            output_name: self.scales[output_index].item()
            for output_name, output_index in (
                self.output_name_to_output_index.items()
            )
        }

        # Add gradients if present. They have the same scale as the
        # corresponding output.
        for output_name in list(scales_dict.keys()):
            gradient_names_for_output = self.dataset_info.targets[
                output_name
            ].gradients
            for gradient_name in gradient_names_for_output:
                scales_dict[output_name + "_" + gradient_name + "_gradients"] = (
                    scales_dict[output_name]
                )

        return scales_dict
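# A minimal usage sketch (not part of the library): how the Scaler might be
# combined with an additive composition model during training. `composition`,
# `dataset_info`, `train_datasets`, and `predictions` are hypothetical
# stand-ins; only the `Scaler` calls mirror the API defined above.
#
#     scaler = Scaler(model_hypers={}, dataset_info=dataset_info)
#     scaler.train_model(
#         train_datasets,
#         additive_models=[composition],
#         treat_as_additive=True,
#     )
#     # one scale (per-atom standard deviation) per target, plus gradients:
#     print(scaler.get_scales_dict())
#
#     # at inference time, predictions made in the scaled space are mapped
#     # back to the original units:
#     predictions = scaler(predictions)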
def remove_scale(
    targets: Dict[str, TensorMap],
    scaler: Scaler,
):
    """
    Scale all targets to a standard deviation of one.

    :param targets: Dictionary containing the targets to be scaled.
    :param scaler: The scaler used to scale the targets.
    """
    scaled_targets = {}
    for target_key in targets.keys():
        scale = float(
            scaler.scales[scaler.output_name_to_output_index[target_key]].item()
        )
        scaled_targets[target_key] = metatensor.torch.multiply(
            targets[target_key], 1.0 / scale
        )
    return scaled_targets
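# Sketch of the intended training-time transformation (hypothetical `targets`
# batch and trained `scaler`): `remove_scale` divides each target by its
# stored scale so that a model can be fitted to unit-standard-deviation
# targets; `Scaler.forward` is the inverse operation, applied to the model's
# outputs.
#
#     scaled_targets = remove_scale(targets, scaler)
#     # ... fit the model on `scaled_targets` ...
#     # later, rescale predictions with `scaler(outputs)`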