Source code for atomate2.forcefields.utils

"""Utils for using a force field (aka an interatomic potential)."""

from __future__ import annotations

import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING

from ase.units import Bohr
from monty.json import MontyDecoder
from typing_extensions import assert_never, deprecated

if TYPE_CHECKING:
    from collections.abc import Callable, Generator
    from typing import Any

    from ase.calculators.calculator import Calculator

    from atomate2.ase.schemas import AseResult

# Names of force-field output fields. This constant is not referenced anywhere
# else in the code visible here.
# NOTE(review): presumably these are the bulky document fields that callers
# treat specially (e.g. store separately) -- confirm against the users of this
# constant.
_FORCEFIELD_DATA_OBJECTS = ["trajectory", "ionic_steps"]


class MLFF(Enum):  # TODO inherit from StrEnum when 3.11+
    """Enumeration of the known machine-learned force fields.

    Besides the regular lookup by value (``MLFF("MACE-MP-0")``), members can
    be resolved from their name (``MLFF("MACE_MP_0")``) or from the string
    form of a member (``MLFF("MLFF.MACE_MP_0")``); see :meth:`_missing_`.
    """

    MACE = "MACE"  # This is MACE-MP-0 (medium), deprecated
    MACE_MP_0 = "MACE-MP-0"
    MACE_MPA_0 = "MACE-MPA-0"
    MACE_MP_0B3 = "MACE-MP-0b3"
    GAP = "GAP"
    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    Forcefield = "Forcefield"  # default placeholder option
    NEP = "NEP"
    Nequip = "Nequip"
    SevenNet = "SevenNet"
    MATPES_R2SCAN = "MatPES-r2SCAN"
    MATPES_PBE = "MatPES-PBE"
    DeepMD = "DeepMD"
    Allegro = "Allegro"
    FAIRChem = "FAIRChem"
    MatterSim = "MatterSim"
    UPET = "UPET"

    @classmethod
    def _missing_(cls, value: Any) -> Any:
        """Resolve member names and ``str(MLFF)`` strings to members.

        Called by ``Enum`` only when ``value`` matches no member value.
        Returns the member whose *name* matches, or ``None`` so that
        ``Enum`` raises its usual ``ValueError``.
        """
        # Keep only the text after the last "MLFF." so that "MLFF.CHGNet"
        # and "CHGNet" are handled alike; non-strings are compared as-is.
        wanted = value.split("MLFF.")[-1] if isinstance(value, str) else value
        return next((member for member in cls if member.name == wanted), None)
# Default keyword arguments per force field. `ForceFieldMixin.__post_init__`
# merges these underneath the user-supplied `calculator_kwargs`, so any key
# given by the user takes precedence.
_DEFAULT_CALCULATOR_KWARGS: dict[MLFF, Any] = {
    MLFF.CHGNet: {"stress_unit": "eV/A3"},
    MLFF.M3GNet: {"stress_unit": "eV/A3"},
    MLFF.NEP: {"model_filename": "nep.txt"},
    MLFF.GAP: {"args_str": "IP GAP", "param_filename": "gap.xml"},
    MLFF.MACE: {"model": "medium"},
    MLFF.MACE_MP_0: {"model": "medium"},
    MLFF.MACE_MPA_0: {"model": "medium-mpa-0"},
    MLFF.MACE_MP_0B3: {"model": "medium-0b3"},
    MLFF.MATPES_PBE: {
        "architecture": "TensorNet",
        "version": "2025.1",
        "stress_unit": "eV/A3",
    },
    MLFF.MATPES_R2SCAN: {
        "architecture": "TensorNet",
        "version": "2025.1",
        "stress_unit": "eV/A3",
    },
    MLFF.FAIRChem: {
        "predict_unit": {"model_name": "uma-s-1p1"},
        "task_name": "omat",
    },
    MLFF.UPET: {
        "model": "pet-mad-s",
        "version": "1.5.0",
    },
}


def _get_standardized_mlff(force_field_name: str | MLFF) -> MLFF:
    """Convert a force field identifier into an `MLFF` member.

    Parameters
    ----------
    force_field_name : str or .MLFF
        The force field to look up. A str may be a member name or a member
        value, optionally carrying the `MLFF.` prefix. Anything that is not
        a str is passed through unchanged.

    Returns
    -------
    MLFF
        The matching enum member.

    Raises
    ------
    ValueError
        If a str is given that is neither a member name nor a member value.
    """
    mlff = force_field_name
    if isinstance(mlff, str):
        # Strip the enum prefix so "MLFF.CHGNet" and "CHGNet" are equivalent.
        key = mlff.split("MLFF.")[-1] if mlff.startswith("MLFF.") else mlff

        # Member names win over member values.
        if key in MLFF.__members__:
            mlff = MLFF[key]
        elif any(key == member.value for member in MLFF):
            mlff = MLFF(key)
        else:
            raise ValueError(f"force_field_name={key} is not a valid MLFF name.")

    if mlff == MLFF.MACE:
        # Emitted from this frame on purpose: stacklevel=2 points the
        # warning at whoever called this function.
        warnings.warn(
            "Because the default MP-trained MACE model is constantly evolving, "
            "we no longer recommend using `MACE` or `MLFF.MACE` to specify "
            "a MACE model. For reproducibility purposes, specifying `MACE` "
            "will still default to MACE-MP-0 (medium), which is identical to "
            "specifying `MLFF.MACE_MP_0`.",
            category=UserWarning,
            stacklevel=2,
        )

    return mlff


@deprecated("Use _get_standardized_mlff instead.")
def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
    """Return the string form of the standardized force field.

    Parameters
    ----------
    force_field_name : str or .MLFF
        The force field to look up, see `_get_standardized_mlff`.

    Returns
    -------
    str
        `str()` of the matching `MLFF` member.
    """
    return str(_get_standardized_mlff(force_field_name))
@dataclass
class ForceFieldMixin:
    """Mix-in providing force-field bookkeeping for makers.

    Attributes
    ----------
    force_field_name : str or MLFF or dict
        Identifier of the force field. A str or `MLFF` is standardized to a
        known `MLFF` member; a dict is kept as calculator metadata and the
        name falls back to the `MLFF.Forcefield` placeholder. After
        initialization this attribute always holds the str form of the
        resolved `MLFF` member.
    calculator_meta : MLFF or dict
        Metadata actually used to instantiate the ASE calculator. Set in
        `__post_init__`; not an ``__init__`` argument.
    calculator_kwargs : dict
        Keyword arguments passed to the ASE calculator. Defaults for the
        chosen force field are filled in underneath the user's values.
    task_document_kwargs : dict
        Additional keyword arguments passed to
        :obj:`.ForceFieldTaskDocument` or another final document schema.
    """

    force_field_name: str | MLFF | dict = MLFF.Forcefield
    calculator_meta: MLFF | dict = field(init=False)
    calculator_kwargs: dict = field(default_factory=dict)
    task_document_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Resolve the force field and fill in dependent attributes."""
        # Cooperate with other classes in the MRO that define a post-init.
        if hasattr(super(), "__post_init__"):
            super().__post_init__()  # type: ignore[misc]

        requested = self.force_field_name
        if isinstance(requested, dict):
            # Custom calculator: keep a copy of its metadata and use the
            # placeholder member as the force field name.
            self.calculator_meta = requested.copy()
            mlff = MLFF.Forcefield
        else:
            mlff = _get_standardized_mlff(requested)
            self.calculator_meta = mlff
        self.force_field_name: str = str(mlff)  # Narrow-down type for mypy

        # Start from the per-force-field defaults, then let the user's
        # values overwrite them.
        merged_kwargs = dict(_DEFAULT_CALCULATOR_KWARGS.get(mlff, {}))
        merged_kwargs.update(self.calculator_kwargs)
        self.calculator_kwargs = merged_kwargs

        # Only supply the name when the user has not given a truthy one.
        if self.task_document_kwargs.get("force_field_name"):
            return
        self.task_document_kwargs["force_field_name"] = self.force_field_name

    def _run_ase_safe(self, *args, **kwargs) -> AseResult:
        """Call ``self.run_ase`` inside `revert_default_dtype`.

        Raises
        ------
        NotImplementedError
            If the class using this mix-in has no ``run_ase`` attribute.
        """
        if hasattr(self, "run_ase"):
            with revert_default_dtype():
                return self.run_ase(*args, **kwargs)
        raise NotImplementedError(
            "You must implement a `run_ase` method to use this method."
        )

    @property
    def calculator(self) -> Calculator:
        """ASE calculator, can be overwritten by user."""
        meta, options = self.calculator_meta, self.calculator_kwargs
        return ase_calculator(meta, **options)

    @property
    def mlff(self) -> MLFF:
        """The MLFF enum corresponding to the force field name."""
        # Take the text after the last "MLFF." (the whole string if absent).
        bare_name = str(self.force_field_name).rpartition("MLFF.")[2]
        return MLFF(bare_name)

    @cached_property
    def ase_calculator_name(self) -> str:
        """The name of the ASE calculator for schemas."""
        meta = self.calculator_meta
        if isinstance(meta, MLFF):
            return str(self.force_field_name)
        if isinstance(meta, dict):
            # Custom calculator: report the name of the decoded callable.
            return _load_calc_cls(meta).__name__
        assert_never(self.calculator_meta)
[docs] def ase_calculator( calculator_meta: str | MLFF | dict, **kwargs: Any ) -> Calculator | None: """ Create an ASE calculator from a given set of metadata. Parameters ---------- calculator_meta : str or dict If a str, should be one of `atomate2.forcefields.MLFF`. If a dict, should be decodable by `monty.json.MontyDecoder`. For example, one can also call the CHGNet calculator as follows ``` calculator_meta = { "@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator" } ``` args : optional args to pass to a calculator kwargs : optional kwargs to pass to a calculator Returns ------- ASE .Calculator """ calculator = None if ( isinstance(calculator_meta, str) and (calculator_meta in map(str, MLFF) or calculator_meta in MLFF) ) or isinstance(calculator_meta, MLFF): calculator_name = MLFF(calculator_meta) match calculator_name: case MLFF.CHGNet | MLFF.M3GNet | MLFF.MATPES_R2SCAN | MLFF.MATPES_PBE: import matgl match calculator_name: case MLFF.M3GNet: path = kwargs.get("path", "M3GNet-MP-2021.2.8-PES") matgl.config.BACKEND = "DGL" case MLFF.CHGNet: path = kwargs.get("path", "CHGNet-MPtrj-2023.12.1-2.7M-PES") matgl.config.BACKEND = "DGL" warnings.warn( "The CHGNet functionality in atomate2 has been migrated " "from the `chgnet` package to `matgl` to ensure continuing " "support. 
If you want to use the `chgnet` package, " "`pip install chgnet` and then specify " '`calculator_meta = {"@module": "chgnet.model.dynamics", ' '"@callable": "CHGNetCalculator"}`', stacklevel=2, ) case MLFF.MATPES_R2SCAN | MLFF.MATPES_PBE: path = ( f"{kwargs.pop('architecture', 'TensorNet')}" f"-{calculator_name.value}" f"-v{kwargs.pop('version', '2025.1')}" "-PES" ) matgl.config.BACKEND = "PYG" if matgl.config.BACKEND == "DGL": from matgl.ext._ase_dgl import PESCalculator else: from matgl.ext._ase_pyg import PESCalculator potential = matgl.load_model(path) calculator = PESCalculator(potential, **kwargs) case MLFF.MACE | MLFF.MACE_MP_0 | MLFF.MACE_MPA_0 | MLFF.MACE_MP_0B3: from mace.calculators import MACECalculator, mace_mp model = kwargs.get("model") if isinstance(model, str | Path) and Path(model).exists(): model_path = model device = kwargs.pop("device", None) or "cpu" kwargs.pop("device", None) calculator = MACECalculator( model_paths=model_path, device=device, **kwargs, ) if kwargs.get("dispersion", False): # See https://github.com/materialsproject/atomate2/issues/1262 # Specifying an explicit model path unsets the dispersio # Reset it here. 
import torch from ase.calculators.mixing import SumCalculator from torch_dftd.torch_dftd3_calculator import ( TorchDFTD3Calculator, ) default_d3_kwargs = { "damping": "bj", "xc": "pbe", "cutoff": 40.0 * Bohr, "dtype": kwargs.get( "default_dtype", torch.get_default_dtype() ), } for k, v in default_d3_kwargs.items(): if k not in kwargs: kwargs[k] = v d3_calc = TorchDFTD3Calculator(device=device, **kwargs) calculator = SumCalculator([calculator, d3_calc]) else: calculator = mace_mp(**kwargs) case MLFF.GAP: from quippy.potential import Potential calculator = Potential(**kwargs) case MLFF.NEP: from calorine.calculators import CPUNEP calculator = CPUNEP(**kwargs) case MLFF.Nequip | MLFF.Allegro: from nequip.ase import NequIPCalculator calculator = getattr( NequIPCalculator, "from_compiled_model" if hasattr(NequIPCalculator, "from_compiled_model") else "from_deployed_model", )(**kwargs) case MLFF.SevenNet: from sevenn.sevennet_calculator import SevenNetCalculator calculator = SevenNetCalculator(**{"model": "7net-0"} | kwargs) case MLFF.DeepMD: from deepmd.calculator import DP calculator = DP(**kwargs) case MLFF.FAIRChem: from fairchem.core import FAIRChemCalculator, pretrained_mlip predict_unit_kwargs = kwargs.pop( "predict_unit", _DEFAULT_CALCULATOR_KWARGS[MLFF.FAIRChem]["predict_unit"], ) calculator = FAIRChemCalculator( pretrained_mlip.get_predict_unit(**predict_unit_kwargs), **{k: v for k, v in kwargs.items() if k != "predict_unit"}, ) case MLFF.MatterSim: from mattersim.forcefield import MatterSimCalculator calculator = MatterSimCalculator(**kwargs) case MLFF.UPET: from upet.calculator import UPETCalculator calculator = UPETCalculator(**kwargs) elif isinstance(calculator_meta, dict): calc_cls = _load_calc_cls(calculator_meta) calculator = calc_cls(**kwargs) if calculator is None: raise ValueError(f"Could not create ASE calculator for {calculator_meta}.") return calculator
def _load_calc_cls(
    calculator_meta: dict,
) -> type[Calculator] | Callable[..., Calculator]:
    """Decode calculator metadata with `monty.json.MontyDecoder`.

    Parameters
    ----------
    calculator_meta : dict
        Metadata handed to `MontyDecoder.process_decoded`.

    Returns
    -------
    The object produced by the decoder for `calculator_meta`.
    """
    decoder = MontyDecoder()
    return decoder.process_decoded(calculator_meta)
@contextmanager
def revert_default_dtype() -> Generator[None]:
    """Context manager for torch.default_dtype.

    Reverts it to whatever torch.get_default_dtype() was when entering the
    context. The original dtype is restored even when the body of the
    ``with`` block raises.

    Originally added for use with MACE(Relax|Static)Maker.
    https://github.com/ACEsuit/mace/issues/328
    """
    import torch

    orig = torch.get_default_dtype()
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises, so a failed
        # calculation cannot leave a changed default dtype behind.
        torch.set_default_dtype(orig)
def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None:
    """Look up the package name that belongs to a force field.

    Parameters
    ----------
    calculator_meta : MLFF or JSONable dict
        Either an MLFF enum member, or the calculator metadata used to load
        the calculator.

    Returns
    -------
    str or None:
        For an MLFF member, the package name listed for it, or None when no
        package is listed (e.g. the `MLFF.Forcefield` placeholder).
        For a dict, the top-level module name of the decoded calculator.
    """
    if isinstance(calculator_meta, MLFF):
        # Lookup table from force field to package name; members that are
        # absent from the table yield None.
        pkg_by_mlff = {
            MLFF.Allegro: "nequip",
            MLFF.Nequip: "nequip",
            MLFF.CHGNet: "matgl",
            MLFF.M3GNet: "matgl",
            MLFF.MATPES_PBE: "matgl",
            MLFF.MATPES_R2SCAN: "matgl",
            MLFF.DeepMD: "deepmd-kit",
            MLFF.FAIRChem: "fairchem.core",
            MLFF.GAP: "quippy-ase",
            MLFF.MACE: "mace-torch",
            MLFF.MACE_MP_0: "mace-torch",
            MLFF.MACE_MPA_0: "mace-torch",
            MLFF.MACE_MP_0B3: "mace-torch",
            MLFF.MatterSim: "mattersim",
            MLFF.NEP: "calorine",
            MLFF.SevenNet: "sevenn",
            MLFF.UPET: "upet",
        }
        return pkg_by_mlff.get(calculator_meta)

    if isinstance(calculator_meta, dict):
        # Everything before the first dot of the defining module's path.
        module_path = _load_calc_cls(calculator_meta).__module__
        return module_path.split(".", 1)[0]

    assert_never(calculator_meta)