import warnings
from typing import Dict, List, Union
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import (
AtomisticModel,
ModelEvaluationOptions,
ModelOutput,
System,
is_atomistic_model,
)
from .data import TargetInfo
from .output_gradient import compute_gradient


def evaluate_model(
model: Union[
torch.nn.Module,
AtomisticModel,
torch.jit.RecursiveScriptModule,
],
systems: List[System],
targets: Dict[str, TargetInfo],
is_training: bool,
check_consistency: bool = False,
) -> Dict[str, TensorMap]:
"""
Evaluate the model (in training or exported) on a set of requested targets.
    :param model: The model to use. This can either be a model in training
        (``torch.nn.Module``) or an exported model (``AtomisticModel`` or
        ``torch.jit.RecursiveScriptModule``).
    :param systems: The systems to use.
    :param targets: The names of the targets to evaluate (keys), along with their
        associated gradients (values).
    :param is_training: Whether the model is being evaluated during training.
    :param check_consistency: Whether to perform consistency checks when evaluating
        an exported model.

    :returns: The predictions of the model for the requested targets.
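
    Example (a minimal sketch; ``MyModel`` and ``read_systems`` below are
    hypothetical placeholders, and the ``TargetInfo`` arguments are elided)::

        model = MyModel()  # any torch.nn.Module exposing supported_outputs()
        systems = read_systems("dataset.xyz")
        targets = {"energy": TargetInfo(...)}  # e.g. requesting "positions" gradients
        predictions = evaluate_model(model, systems, targets, is_training=True)
        forces = -predictions["energy"].block().gradient("positions").values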
"""
# ignore warnings about gradients
warnings.filterwarnings(
action="ignore",
message="This system's positions or cell requires grad, but the neighbors",
)
model_outputs = _get_supported_outputs(model)
    # Check that all requested targets are within the model's supported outputs:
    if not set(targets.keys()).issubset(model_outputs.keys()):
        raise ValueError(
            f"Some of the requested targets ({sorted(targets.keys())}) are not "
            f"among the model's supported outputs ({sorted(model_outputs.keys())})"
        )
# Find if there are any energy targets that require gradients:
energy_targets = []
energy_targets_that_require_position_gradients = []
energy_targets_that_require_strain_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if model_outputs[target_name].quantity == "energy":
energy_targets.append(target_name)
# Check if the energy requires gradients:
if "positions" in targets[target_name].gradients:
energy_targets_that_require_position_gradients.append(target_name)
if "strain" in targets[target_name].gradients:
energy_targets_that_require_strain_gradients.append(target_name)
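
    # Re-create the systems so that positions (and, if strain gradients are needed,
    # a strain applied to positions and cell) are attached to the autograd graph: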
new_systems = []
strains = []
for system in systems:
new_system, strain = _prepare_system(
system,
positions_grad=len(energy_targets_that_require_position_gradients) > 0,
strain_grad=len(energy_targets_that_require_strain_gradients) > 0,
check_consistency=check_consistency,
)
new_systems.append(new_system)
strains.append(strain)
systems = new_systems
# Based on the keys of the targets, get the outputs of the model:
model_outputs = _get_model_outputs(model, systems, targets, check_consistency)
for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
target_requires_pos_gradients = (
energy_target in energy_targets_that_require_position_gradients
)
target_requires_strain_gradients = (
energy_target in energy_targets_that_require_strain_gradients
)
if target_requires_pos_gradients and target_requires_strain_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
[system.positions for system in systems] + strains,
is_training=is_training,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient(
"positions", _position_gradients_to_block(gradients[: len(systems)])
)
new_block.add_gradient(
"strain",
_strain_gradients_to_block(gradients[len(systems) :]),
)
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
elif target_requires_pos_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
[system.positions for system in systems],
is_training=is_training,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient("positions", _position_gradients_to_block(gradients))
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
elif target_requires_strain_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
strains,
is_training=is_training,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient("strain", _strain_gradients_to_block(gradients))
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
        else:
            # this energy target does not require any gradients
            pass
return model_outputs


def _position_gradients_to_block(gradients_list):
"""Convert a list of position gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""
    # `gradients_list` is a list of tensors, one per structure, each of shape
    # (n_atoms, 3); concatenate them along the atom dimension
gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension
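    # each gradient row is labelled by the structure index ("sample") and by the
    # atom index within that structure ("atom")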
samples = Labels(
names=["sample", "atom"],
values=torch.stack(
[
torch.concatenate(
[
                        torch.tensor([i] * len(gradient))
                        for i, gradient in enumerate(gradients_list)
]
),
torch.concatenate(
                    [torch.arange(len(gradient)) for gradient in gradients_list]
),
],
dim=1,
),
)
components = [
Labels(
names=["xyz"],
values=torch.tensor([[0], [1], [2]]),
)
]
return TensorBlock(
values=gradients,
samples=samples.to(gradients.device),
components=[c.to(gradients.device) for c in components],
properties=Labels("energy", torch.tensor([[0]])).to(gradients.device),
)


def _strain_gradients_to_block(gradients_list):
"""Convert a list of strain gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""
gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension
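    # there is one 3x3 strain gradient per structure, so the samples only need the
    # structure index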
samples = Labels(
names=["sample"], values=torch.arange(len(gradients_list)).unsqueeze(-1)
)
components = [
Labels(
names=["xyz_1"],
values=torch.tensor([[0], [1], [2]]),
),
Labels(
names=["xyz_2"],
values=torch.tensor([[0], [1], [2]]),
),
]
return TensorBlock(
values=gradients,
samples=samples.to(gradients.device),
components=[c.to(gradients.device) for c in components],
properties=Labels("energy", torch.tensor([[0]])).to(gradients.device),
)


def _get_supported_outputs(
model: Union[torch.nn.Module, torch.jit.RecursiveScriptModule],
):
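    """Return the outputs a model can compute, for both exported models (through
    their capabilities) and models in training (through ``supported_outputs()``)."""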
if is_atomistic_model(model):
return model.capabilities().outputs
else:
return model.supported_outputs()


def _get_model_outputs(
model: Union[
torch.nn.Module,
AtomisticModel,
torch.jit.RecursiveScriptModule,
],
systems: List[System],
targets: Dict[str, TargetInfo],
check_consistency: bool,
) -> Dict[str, TensorMap]:
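    """Evaluate the model on the systems, returning one ``TensorMap`` per target."""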
if is_atomistic_model(model):
# put together an EvaluationOptions object
options = ModelEvaluationOptions(
length_unit="", # this is only needed for unit conversions in MD engines
outputs={
key: ModelOutput(
quantity=value.quantity, unit=value.unit, per_atom=value.per_atom
)
for key, value in targets.items()
},
)
return model(systems, options, check_consistency=check_consistency)
else:
return model(
systems,
{
key: ModelOutput(
quantity=value.quantity, unit=value.unit, per_atom=value.per_atom
)
for key, value in targets.items()
},
)


@torch.jit.script
def _prepare_system( # pragma: no cover
system: System, positions_grad: bool, strain_grad: bool, check_consistency: bool
):
"""
Prepares a system for gradient calculation.
"""
if strain_grad:
strain = torch.eye(
3,
dtype=system.cell.dtype,
device=system.cell.device,
).requires_grad_(True)
new_system = System(
positions=system.positions @ strain,
cell=system.cell @ strain,
types=system.types,
pbc=system.pbc,
)
else:
if positions_grad:
new_system = System(
positions=system.positions.detach().clone().requires_grad_(True),
cell=system.cell,
types=system.types,
pbc=system.pbc,
)
strain = None
else:
new_system = System(
positions=system.positions,
cell=system.cell,
types=system.types,
pbc=system.pbc,
)
strain = None
for nl_options in system.known_neighbor_lists():
nl = system.get_neighbor_list(nl_options)
new_system.add_neighbor_list(nl_options, nl)
return new_system, strain