"""Tools and functions common to all forcefields."""from__future__importannotationsimportwarningsfromenumimportEnumfromtypingimportTYPE_CHECKINGifTYPE_CHECKING:fromtypingimportAny
class MLFF(Enum):  # TODO inherit from StrEnum when 3.11+
    """Names of ML force fields.

    Each member's value is the string identifier of a force field. ``MACE``
    is deprecated and corresponds to MACE-MP-0 (medium); ``Forcefield`` is a
    default placeholder option.
    """

    MACE = "MACE"  # This is MACE-MP-0 (medium), deprecated
    MACE_MP_0 = "MACE-MP-0"
    MACE_MPA_0 = "MACE-MPA-0"
    MACE_MP_0B3 = "MACE-MP-0b3"
    GAP = "GAP"
    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    Forcefield = "Forcefield"  # default placeholder option
    NEP = "NEP"
    Nequip = "Nequip"
    SevenNet = "SevenNet"
    MATPES_R2SCAN = "MatPES-r2SCAN"
    MATPES_PBE = "MatPES-PBE"

    @classmethod
    def _missing_(cls, value: Any) -> Any:
        """Allow input of str(MLFF) as valid enum.

        Enum invokes this hook only after lookup by value has failed, so e.g.
        ``MLFF("MLFF.MACE_MP_0")`` and ``MLFF("MACE_MP_0")`` both resolve to
        ``MLFF.MACE_MP_0``. A string input keeps only the text after its last
        ``"MLFF."`` and is then compared against member *names*. Returning
        ``None`` makes Enum raise its usual ``ValueError``.
        """
        # Non-str values are left untouched and still compared against the
        # member names (for ordinary objects this never matches).
        wanted = value.split("MLFF.")[-1] if isinstance(value, str) else value
        return next((member for member in cls if member.name == wanted), None)
def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
    """
    Get the standardized force field name.

    Parameters
    ----------
    force_field_name : str or .MLFF
        The name of the force field. This may be an ``MLFF`` member, a member
        name (e.g. ``"MACE_MP_0"``) or a member value (e.g. ``"MACE-MP-0"``).
        Any other string is passed through unchanged.

    Returns
    -------
    str : the name of the forcefield from MLFF, i.e. ``str()`` of the matched
        member (``"MLFF.<NAME>"``), or the input string if nothing matched.

    Warns
    -----
    UserWarning
        When the result refers to the generic, deprecated ``MACE`` entry.
    """
    resolved = force_field_name
    if isinstance(resolved, str):
        # ensure the name uses enum format: member names take precedence over
        # member values
        if resolved in MLFF.__members__:
            resolved = MLFF[resolved]
        elif any(resolved == member.value for member in MLFF):
            resolved = MLFF(resolved)

    formatted = str(resolved)

    if formatted in ("MLFF.MACE", "MACE"):
        warnings.warn(
            "Because the default MP-trained MACE model is constantly evolving, "
            "we no longer recommend using `MACE` or `MLFF.MACE` to specify "
            "a MACE model. For reproducibility purposes, specifying `MACE` "
            "will still default to MACE-MP-0 (medium), which is identical to "
            "specifying `MLFF.MACE_MP_0`.",
            category=UserWarning,
            stacklevel=2,
        )
    return formatted