"""Utils for using a force field (aka an interatomic potential)."""
from __future__ import annotations
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
from ase.units import Bohr
from ase.units import GPa as _GPa_to_eV_per_A3
from monty.json import MontyDecoder
if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any
from ase.calculators.calculator import Calculator
from atomate2.ase.schemas import AseResult
# Field names treated as force-field "data objects".
# NOTE(review): this constant is not referenced in the visible part of the
# module; presumably job definitions use it to route these (potentially large)
# task-document fields to a separate data store -- confirm against callers.
_FORCEFIELD_DATA_OBJECTS = ["trajectory", "ionic_steps"]
[docs]
class MLFF(Enum):  # TODO inherit from StrEnum when 3.11+
    """Enumeration of the supported machine-learned force fields.

    Each member's value is the display label of the model. Besides the
    regular lookup by value (``MLFF("MACE-MP-0")``), members can be obtained
    from their name (``MLFF("MACE_MP_0")``) or from the string form of a
    member (``MLFF("MLFF.MACE_MP_0")``); see :meth:`_missing_`.
    """

    MACE = "MACE"  # This is MACE-MP-0 (medium), deprecated
    MACE_MP_0 = "MACE-MP-0"
    MACE_MPA_0 = "MACE-MPA-0"
    MACE_MP_0B3 = "MACE-MP-0b3"
    GAP = "GAP"
    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    Forcefield = "Forcefield"  # default placeholder option
    NEP = "NEP"
    Nequip = "Nequip"
    SevenNet = "SevenNet"
    MATPES_R2SCAN = "MatPES-r2SCAN"
    MATPES_PBE = "MatPES-PBE"
    DeepMD = "DeepMD"

    @classmethod
    def _missing_(cls, value: Any) -> Any:
        """Resolve a member name or ``str(member)`` to the member itself.

        ``Enum`` calls this hook only when ``value`` equals no member value.
        Returning ``None`` makes ``Enum`` raise its usual ``ValueError``.
        """
        # For strings keep only the text after the last "MLFF." so that
        # "MLFF.CHGNet" is matched against the member name "CHGNet".
        wanted = value.rpartition("MLFF.")[2] if isinstance(value, str) else value
        return next((ff for ff in cls if ff.name == wanted), None)
# Per-force-field default keyword arguments for the ASE calculator.
# Force fields without an entry get no defaults. Every key owns its own dict
# object, so entries are never aliased to each other.
_DEFAULT_CALCULATOR_KWARGS = {
    # CHGNet and M3GNet share the same stress-weight default (the ASE `GPa`
    # unit, imported above under the alias `_GPa_to_eV_per_A3`).
    **{
        ff: {"stress_weight": _GPa_to_eV_per_A3}
        for ff in (MLFF.CHGNet, MLFF.M3GNet)
    },
    MLFF.NEP: {"model_filename": "nep.txt"},
    MLFF.GAP: {"args_str": "IP GAP", "param_filename": "gap.xml"},
    # MACE family: only the model identifier differs between entries.
    MLFF.MACE: {"model": "medium"},
    MLFF.MACE_MP_0: {"model": "medium"},
    MLFF.MACE_MPA_0: {"model": "medium-mpa-0"},
    MLFF.MACE_MP_0B3: {"model": "medium-0b3"},
    # Both MatPES functionals use identical defaults.
    **{
        ff: {
            "architecture": "TensorNet",
            "version": "2025.1",
            "stress_unit": "eV/A3",
        }
        for ff in (MLFF.MATPES_PBE, MLFF.MATPES_R2SCAN)
    },
}
def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
    """
    Convert a force field identifier to its standardized string form.

    Strings that equal an ``MLFF`` member name or member value are first
    mapped to that member, so the result has the form ``"MLFF.<name>"`` for
    every known force field. Any other string is returned unchanged.

    Parameters
    ----------
    force_field_name : str or .MLFF
        The name of the force field

    Returns
    -------
    str : the name of the forcefield from MLFF

    Warns
    -----
    UserWarning
        If the deprecated generic ``MACE`` name is requested.
    """
    resolved: str | MLFF = force_field_name
    if isinstance(resolved, str):
        # Member names take precedence over member values.
        if resolved in MLFF.__members__:
            resolved = MLFF.__members__[resolved]
        elif any(resolved == member.value for member in MLFF):
            resolved = MLFF(resolved)

    formatted = str(resolved)
    if formatted not in ("MLFF.MACE", "MACE"):
        return formatted

    # The generic MACE alias still works but is discouraged.
    warnings.warn(
        "Because the default MP-trained MACE model is constantly evolving, "
        "we no longer recommend using `MACE` or `MLFF.MACE` to specify "
        "a MACE model. For reproducibility purposes, specifying `MACE` "
        "will still default to MACE-MP-0 (medium), which is identical to "
        "specifying `MLFF.MACE_MP_0`.",
        category=UserWarning,
        stacklevel=2,
    )
    return formatted
[docs]
@dataclass
class ForceFieldMixin:
    """Mix-in class for force-fields.

    Attributes
    ----------
    force_field_name : str or MLFF
        Name of the forcefield which will be
        correctly deserialized/standardized if the forcefield is
        a known `MLFF`. Names that are not a known `MLFF` are kept
        verbatim and receive no default calculator kwargs.
    calculator_kwargs : dict = field(default_factory=dict)
        Keyword arguments that will get passed to the ASE calculator.
    task_document_kwargs: dict = field(default_factory=dict)
        Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`
        or another final document schema.
    """

    force_field_name: str | MLFF = MLFF.Forcefield
    calculator_kwargs: dict = field(default_factory=dict)
    task_document_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Standardize the force field name and fill in default kwargs."""
        if hasattr(super(), "__post_init__"):
            super().__post_init__()  # type: ignore[misc]

        self.force_field_name = _get_formatted_ff_name(self.force_field_name)

        # A name that is not a known MLFF must not make construction fail:
        # it simply has no default calculator kwargs.
        try:
            known_ff = MLFF(self.force_field_name.split("MLFF.")[-1])
        except ValueError:
            default_kwargs: dict = {}
        else:
            default_kwargs = _DEFAULT_CALCULATOR_KWARGS.get(known_ff, {})

        # Pad calculator_kwargs with default values, but permit user to override them
        self.calculator_kwargs = {**default_kwargs, **self.calculator_kwargs}

        # Work on a copy so a dict shared between several instances is never
        # modified, which would leak one instance's force field name into others.
        self.task_document_kwargs = dict(self.task_document_kwargs)
        if not self.task_document_kwargs.get("force_field_name"):
            self.task_document_kwargs["force_field_name"] = str(self.force_field_name)

    def _run_ase_safe(self, *args: Any, **kwargs: Any) -> AseResult:
        """Call ``self.run_ase`` and restore torch's default dtype afterwards.

        Raises
        ------
        NotImplementedError
            If the class using this mix-in does not define ``run_ase``.
        """
        if not hasattr(self, "run_ase"):
            raise NotImplementedError(
                "You must implement a `run_ase` method to use this method."
            )

        with revert_default_dtype():
            return self.run_ase(*args, **kwargs)

    @property
    def calculator(self) -> Calculator:
        """ASE calculator, can be overwritten by user."""
        return ase_calculator(
            str(self.force_field_name),  # make mypy happy
            **self.calculator_kwargs,
        )
[docs]
def ase_calculator(
    calculator_meta: str | MLFF | dict, **kwargs: Any
) -> Calculator:
    """
    Create an ASE calculator from a given set of metadata.

    Parameters
    ----------
    calculator_meta : str or .MLFF or dict
        If a str, should be the string form of one of
        `atomate2.forcefields.MLFF` (e.g. ``str(MLFF.CHGNet)``).
        If a dict, should be decodable by `monty.json.MontyDecoder`.
        For example, one can also call the CHGNet calculator as follows
        ```
            calculator_meta = {
                "@module": "chgnet.model.dynamics",
                "@callable": "CHGNetCalculator"
            }
        ```
    kwargs : optional kwargs to pass to a calculator

    Returns
    -------
    ASE .Calculator

    Raises
    ------
    ValueError
        If no calculator could be created from ``calculator_meta``.
    """
    calculator = None

    is_known_mlff = isinstance(calculator_meta, MLFF) or (
        isinstance(calculator_meta, str) and calculator_meta in map(str, MLFF)
    )

    if is_known_mlff:
        calculator_name = MLFF(calculator_meta)

        # Each backend is imported lazily so that only the package of the
        # requested force field needs to be installed.
        if calculator_name == MLFF.CHGNet:
            from chgnet.model.dynamics import CHGNetCalculator

            calculator = CHGNetCalculator(**kwargs)

        elif calculator_name in (MLFF.M3GNet, MLFF.MATPES_R2SCAN, MLFF.MATPES_PBE):
            import matgl
            from matgl.ext.ase import PESCalculator

            if calculator_name == MLFF.M3GNet:
                # NOTE(review): `path` is read but left in kwargs, so it is also
                # forwarded to PESCalculator -- confirm this is intended.
                path = kwargs.get("path", "M3GNet-MP-2021.2.8-PES")
            else:
                # MatPES model names are assembled from architecture,
                # functional label and version; these keys are consumed here.
                architecture = kwargs.pop("architecture", "TensorNet")
                matpes_version = kwargs.pop("version", "2025.1")
                path = f"{architecture}-{calculator_name.value}-v{matpes_version}-PES"

            potential = matgl.load_model(path)
            calculator = PESCalculator(potential, **kwargs)

        elif calculator_name in (
            MLFF.MACE,
            MLFF.MACE_MP_0,
            MLFF.MACE_MPA_0,
            MLFF.MACE_MP_0B3,
        ):
            from mace.calculators import MACECalculator, mace_mp

            model = kwargs.get("model")
            if isinstance(model, str | Path) and Path(model).exists():
                # `model` points to an existing file: load it directly.
                # `device` is removed from kwargs so it is not passed twice.
                device = kwargs.pop("device", None) or "cpu"
                calculator = MACECalculator(
                    model_paths=model,
                    device=device,
                    **kwargs,
                )

                if kwargs.get("dispersion", False):
                    # See https://github.com/materialsproject/atomate2/issues/1262
                    # Specifying an explicit model path unsets the dispersion
                    # correction. Reset it here.
                    import torch
                    from ase.calculators.mixing import SumCalculator
                    from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

                    # NOTE(review): `default_dtype` is forwarded unchanged as the
                    # D3 `dtype`; confirm both calculators accept the same form.
                    default_d3_kwargs = {
                        "damping": "bj",
                        "xc": "pbe",
                        "cutoff": 40.0 * Bohr,
                        "dtype": kwargs.get("default_dtype", torch.get_default_dtype()),
                    }
                    # User-supplied values win over the D3 defaults.
                    for key, default in default_d3_kwargs.items():
                        kwargs.setdefault(key, default)

                    d3_calc = TorchDFTD3Calculator(device=device, **kwargs)
                    calculator = SumCalculator([calculator, d3_calc])
            else:
                calculator = mace_mp(**kwargs)

        elif calculator_name == MLFF.GAP:
            from quippy.potential import Potential

            calculator = Potential(**kwargs)

        elif calculator_name == MLFF.NEP:
            from calorine.calculators import CPUNEP

            calculator = CPUNEP(**kwargs)

        elif calculator_name == MLFF.Nequip:
            from nequip.ase import NequIPCalculator

            # Use `from_compiled_model` when the installed nequip provides it,
            # otherwise fall back to `from_deployed_model`.
            loader_name = (
                "from_compiled_model"
                if hasattr(NequIPCalculator, "from_compiled_model")
                else "from_deployed_model"
            )
            calculator = getattr(NequIPCalculator, loader_name)(**kwargs)

        elif calculator_name == MLFF.SevenNet:
            from sevenn.sevennet_calculator import SevenNetCalculator

            calculator = SevenNetCalculator(**{"model": "7net-0"} | kwargs)

        elif calculator_name == MLFF.DeepMD:
            from deepmd.calculator import DP

            calculator = DP(**kwargs)

    elif isinstance(calculator_meta, dict):
        calc_cls = MontyDecoder().process_decoded(calculator_meta)
        calculator = calc_cls(**kwargs)

    # Reached for unknown strings and for MLFF members without a branch above
    # (e.g. the `Forcefield` placeholder).
    if calculator is None:
        raise ValueError(f"Could not create ASE calculator for {calculator_meta}.")

    return calculator
[docs]
@contextmanager
def revert_default_dtype() -> Generator[None]:
    """Context manager for torch.default_dtype.

    Reverts it to whatever torch.get_default_dtype() was when entering the context.
    The original dtype is restored even if the managed block raises.
    Originally added for use with MACE(Relax|Static)Maker.
    https://github.com/ACEsuit/mace/issues/328
    """
    import torch

    orig = torch.get_default_dtype()
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike, so a failing
        # calculation cannot leave a modified global default dtype behind.
        torch.set_default_dtype(orig)