"""Utils for using a force field (aka an interatomic potential)."""from__future__importannotationsimportjsonfromcontextlibimportcontextmanagerfromtypingimportTYPE_CHECKINGfrommonty.jsonimportMontyDecoderfromatomate2.forcefieldsimportMLFFifTYPE_CHECKING:fromcollections.abcimportGeneratorfromtypingimportAnyfromase.calculators.calculatorimportCalculator
def ase_calculator(
    calculator_meta: str | MLFF | dict, **kwargs: Any
) -> Calculator | None:
    """
    Create an ASE calculator from a given set of metadata.

    Parameters
    ----------
    calculator_meta : str or MLFF or dict
        If an `atomate2.forcefields.MLFF` member, that force field is used.
        If a str, the text after an optional ``"MLFF."`` prefix must be a value
        of `atomate2.forcefields.MLFF` (e.g. ``"MLFF.CHGNet"`` or ``"CHGNet"``).
        If a dict, should be decodable by `monty.json.MontyDecoder`.
        For example, one can also call the CHGNet calculator as follows
        ```
            calculator_meta = {
                "@module": "chgnet.model.dynamics",
                "@callable": "CHGNetCalculator"
            }
        ```
    kwargs : optional kwargs to pass to a calculator.
        For M3GNet, the optional ``path`` key is passed to ``matgl.load_model``
        (default ``"M3GNet-MP-2021.2.8-PES"``) and is not forwarded to the
        calculator itself. For SevenNet, ``model`` defaults to ``"7net-0"``.

    Returns
    -------
    ASE .Calculator

    Raises
    ------
    ValueError
        If no calculator could be created from ``calculator_meta``.
    """
    calculator = None

    # Resolve which MLFF (if any) was requested. An MLFF member is used as-is;
    # a string is looked up by its text after an optional "MLFF." prefix.
    # Strings that are not MLFF values fall through to the ValueError below.
    calculator_name: MLFF | None = None
    if isinstance(calculator_meta, MLFF):
        calculator_name = calculator_meta
    elif isinstance(calculator_meta, str):
        try:
            calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])
        except ValueError:
            calculator_name = None

    if calculator_name is not None:
        # Third-party packages are imported lazily so that only the selected
        # force field's dependencies need to be installed.
        if calculator_name == MLFF.CHGNet:
            from chgnet.model.dynamics import CHGNetCalculator

            calculator = CHGNetCalculator(**kwargs)

        elif calculator_name == MLFF.M3GNet:
            import matgl
            from matgl.ext.ase import PESCalculator

            # "path" selects the model to load; it is a loader option rather
            # than a calculator argument, so remove it before forwarding kwargs.
            path = kwargs.pop("path", "M3GNet-MP-2021.2.8-PES")
            potential = matgl.load_model(path)
            calculator = PESCalculator(potential, **kwargs)

        elif calculator_name == MLFF.MACE:
            from mace.calculators import mace_mp

            calculator = mace_mp(**kwargs)

        elif calculator_name == MLFF.GAP:
            from quippy.potential import Potential

            calculator = Potential(**kwargs)

        elif calculator_name == MLFF.NEP:
            from calorine.calculators import CPUNEP

            calculator = CPUNEP(**kwargs)

        elif calculator_name == MLFF.Nequip:
            from nequip.ase import NequIPCalculator

            calculator = NequIPCalculator.from_deployed_model(**kwargs)

        elif calculator_name == MLFF.SevenNet:
            from sevenn.sevennet_calculator import SevenNetCalculator

            # User-supplied kwargs override the default model.
            calculator = SevenNetCalculator(**({"model": "7net-0"} | kwargs))

    elif isinstance(calculator_meta, dict):
        calc_cls = MontyDecoder().decode(json.dumps(calculator_meta))
        calculator = calc_cls(**kwargs)

    if calculator is None:
        raise ValueError(f"Could not create ASE calculator for {calculator_meta}.")

    return calculator
@contextmanager
def revert_default_dtype() -> Generator[None, None, None]:
    """Context manager for torch.default_dtype.

    Reverts it to whatever torch.get_default_dtype() was when entering the
    context, including when the body of the ``with`` block raises.

    Originally added for use with MACE(Relax|Static)Maker.
    https://github.com/ACEsuit/mace/issues/328
    """
    import torch

    orig = torch.get_default_dtype()
    try:
        yield
    finally:
        # Restore in ``finally``: an exception raised inside the ``with`` block
        # is re-raised at the ``yield``, and without this the changed default
        # dtype would remain in effect after the block exits.
        torch.set_default_dtype(orig)