Source code for atomate2.forcefields.schemas

"""Schema definitions for force field tasks."""

from __future__ import annotations

from typing import Any, Optional

from emmet.core.utils import ValueEnum
from emmet.core.vasp.calculation import StoreTrajectoryOption
from monty.dev import deprecated
from pydantic import Field
from pymatgen.core import Structure

from atomate2.ase.schemas import AseObject, AseResult, AseStructureTaskDoc, AseTaskDoc
from atomate2.forcefields import MLFF


[docs] @deprecated(replacement=AseResult, deadline=(2025, 1, 1)) class ForcefieldResult(AseResult): """Schema to store outputs; deprecated.""" final_structure: Optional[Structure] = Field( None, description="The structure in the final trajectory frame." )
[docs] def model_post_init(self, __context: Any) -> None: """Populate final_structure attr.""" self.final_structure = getattr( self, "final_structure", self.final_mol_or_struct )
[docs] @deprecated(replacement=AseObject, deadline=(2025, 1, 1)) class ForcefieldObject(ValueEnum): """Types of force-field output data objects.""" TRAJECTORY = "trajectory"
[docs] class ForceFieldTaskDocument(AseStructureTaskDoc): """Document containing information on structure manipulation using a force field.""" forcefield_name: Optional[str] = Field( None, description="name of the interatomic potential used for relaxation.", ) forcefield_version: Optional[str] = Field( "Unknown", description="version of the interatomic potential used for relaxation.", ) dir_name: Optional[str] = Field( None, description="Directory where the force field calculations are performed." ) included_objects: Optional[list[AseObject]] = Field( None, description="list of forcefield objects included with this task document" ) objects: Optional[dict[AseObject, Any]] = Field( None, description="Forcefield objects associated with this task" ) is_force_converged: Optional[bool] = Field( None, description=( "Whether the calculation is converged with respect " "to interatomic forces." ), )
[docs] @classmethod def from_ase_compatible_result( cls, ase_calculator_name: str, result: AseResult, steps: int, relax_kwargs: dict = None, optimizer_kwargs: dict = None, fix_symmetry: bool = False, symprec: float = 1e-2, ionic_step_data: tuple = ( "energy", "forces", "magmoms", "stress", "mol_or_struct", ), store_trajectory: StoreTrajectoryOption = StoreTrajectoryOption.NO, tags: list[str] | None = None, **task_document_kwargs, ) -> ForceFieldTaskDocument: """Create an AseTaskDoc for a task that has ASE-compatible outputs. Parameters ---------- ase_calculator_name : str Name of the ASE calculator used. result : AseResult The output results from the task. fix_symmetry : bool Whether to fix the symmetry of the ions during relaxation. symprec : float Tolerance for symmetry finding in case of fix_symmetry. steps : int Maximum number of ionic steps allowed during relaxation. relax_kwargs : dict Keyword arguments that will get passed to :obj:`Relaxer.relax`. optimizer_kwargs : dict Keyword arguments that will get passed to :obj:`Relaxer()`. ionic_step_data : tuple Which data to save from each ionic step. store_trajectory: whether to set the StoreTrajectoryOption tags : list[str] or None A list of tags for the task. task_document_kwargs : dict Additional keyword args passed to :obj:`.AseTaskDoc()`. """ ase_task_doc = AseTaskDoc.from_ase_compatible_result( ase_calculator_name=ase_calculator_name, result=result, fix_symmetry=fix_symmetry, symprec=symprec, steps=steps, relax_kwargs=relax_kwargs, optimizer_kwargs=optimizer_kwargs, ionic_step_data=ionic_step_data, store_trajectory=store_trajectory, tags=tags, **task_document_kwargs, ) ff_kwargs = { "forcefield_name": task_document_kwargs.get( "forcefield_name", ase_calculator_name ) } # map force field name to its package name model_to_pkg_map = { MLFF.M3GNet: "matgl", MLFF.CHGNet: "chgnet", MLFF.MACE: "mace-torch", MLFF.GAP: "quippy-ase", MLFF.Nequip: "nequip", } if pkg_name := {str(k): v for k, v in model_to_pkg_map.items()}.get( ff_kwargs["forcefield_name"] ): import importlib.metadata ff_kwargs["forcefield_version"] = importlib.metadata.version(pkg_name) return cls.from_ase_task_doc(ase_task_doc, **ff_kwargs)
@property def forcefield_objects(self) -> Optional[dict[AseObject, Any]]: """Alias `objects` attr for backwards compatibility.""" return self.objects