from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
import numpy as np
from emmet.core.thermo import ThermoDoc, validate_thermo_id
from emmet.core.types.enums import ThermoType
from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType
from pydantic import TypeAdapter
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.core import Element
from mp_api.client.core import BaseRester
from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.settings import DEFAULT_THERMOTYPE
from mp_api.client.core.utils import validate_ids
if TYPE_CHECKING:
from collections.abc import Sequence
from enums import Enum
[docs]
class ThermoRester(BaseRester):
    """Rester for thermodynamic data served from the ``materials/thermo`` endpoint."""

    suffix = "materials/thermo"
    document_model = ThermoDoc  # type: ignore
    primary_key = "material_id"

    @staticmethod
    def _check_thermo_types(thermo_types: Sequence[str | Enum]) -> set[str]:
        """Check if a user has input any invalid thermo types.

        Enum members are converted to their ``.value``. ``"R2SCAN"`` is
        rewritten to ``"r2SCAN"`` because the phase-diagram table stores
        ``"r2SCAN"`` rather than ``"R2SCAN"`` (emmet mixes ThermoType and
        RunType values here).

        TODO: coerce upstream? allow case-insensitivity in emmet?

        Args:
            thermo_types (Sequence of str or Enum): thermo types the user
                has queried for.

        Returns:
            set of str: validated thermo types.

        Raises:
            ValueError: if any invalid thermo types are input.
        """
        t_types: set[str] = {
            t if isinstance(t, str) else t.value for t in thermo_types
        }
        # Normalize to the spelling used by the phase-diagram table.
        t_types = {"r2SCAN" if t == "R2SCAN" else t for t in t_types}
        valid_types = {"r2SCAN", *map(str, ThermoType.__members__.values())}
        if invalid_types := t_types - valid_types:
            raise ValueError(
                f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}"
            )
        return t_types

    def search(
        self,
        material_ids: str | list[str] | None = None,
        chemsys: str | list[str] | None = None,
        energy_above_hull: tuple[float, float] | None = None,
        equilibrium_reaction_energy: tuple[float, float] | None = None,
        formation_energy: tuple[float, float] | None = None,
        formula: str | list[str] | None = None,
        is_stable: bool | None = None,
        num_elements: tuple[int, int] | None = None,
        thermo_ids: str | list[str] | None = None,
        thermo_types: ThermoType | str | list[ThermoType | str] | None = None,
        total_energy: tuple[float, float] | None = None,
        uncorrected_energy: tuple[float, float] | None = None,
        num_chunks: int | None = None,
        chunk_size: int = 1000,
        all_fields: bool = True,
        fields: list[str] | None = None,
    ) -> list[ThermoDoc] | list[dict]:
        """Query core thermo docs using a variety of search criteria.

        Arguments:
            material_ids (str, List[str]): A single Material ID string or list of strings
                (e.g., mp-149, [mp-149, mp-13]).
            chemsys (str, List[str]): A chemical system or list of chemical systems
                (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
            energy_above_hull (Tuple[float,float]): Minimum and maximum energy above the hull in eV/atom to consider.
            equilibrium_reaction_energy (Tuple[float,float]): Minimum and maximum equilibrium reaction energy
                in eV/atom to consider.
            formation_energy (Tuple[float,float]): Minimum and maximum formation energy in eV/atom to consider.
            formula (str, List[str]): A formula including anonymized formula
                or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
                (e.g., [Fe2O3, ABO3]).
            is_stable (bool): Whether the material is stable.
            num_elements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider.
                A single int n is treated as (n, n).
            thermo_ids (str, List[str]): A thermo ID or list of thermo IDs to return data for. A thermo ID
                is a combination of the Materials Project ID and thermo type (e.g. mp-149_GGA_GGA+U).
            thermo_types (ThermoType, str, or List of these): Thermo/run type(s) to return data for
                (e.g. ThermoType.GGA_GGA_U).
            total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider.
            uncorrected_energy (Tuple[float,float]): Minimum and maximum uncorrected total
                energy in eV/atom to consider.
            num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
            chunk_size (int): Number of data entries per chunk.
            all_fields (bool): Whether to return all fields in the document. Defaults to True.
            fields (List[str]): List of fields in ThermoDoc to return data for.
                Default is material_id and last_updated if all_fields is False.

        Returns:
            ([ThermoDoc], [dict]) List of thermo documents or dictionaries.

        Raises:
            MPRestError: if any entry of ``thermo_ids`` fails validation.
            ValueError: if any entry of ``thermo_types`` is not a valid thermo type.
        """
        query_params: dict = {}

        if formula:
            if isinstance(formula, str):
                formula = [formula]
            query_params["formula"] = ",".join(formula)

        if chemsys:
            if isinstance(chemsys, str):
                chemsys = [chemsys]
            query_params["chemsys"] = ",".join(chemsys)

        if material_ids:
            if isinstance(material_ids, str):
                material_ids = [material_ids]
            query_params["material_ids"] = ",".join(validate_ids(material_ids))

        if thermo_ids:
            # A bare string would otherwise be iterated character by character.
            if isinstance(thermo_ids, str):
                thermo_ids = [thermo_ids]
            try:
                for thermo_id in thermo_ids:
                    validate_thermo_id(thermo_id)
            except Exception as exc:
                raise MPRestError(
                    f"At least one thermo_id in: {thermo_ids} is invalid."
                    " Try using the validate_thermo_id function from emmet.core.thermo"
                    " to test your inputs."
                ) from exc
            query_params["thermo_ids"] = ",".join(thermo_ids)

        if thermo_types:
            # Wrap a single value so it is not iterated as a sequence of characters.
            if isinstance(thermo_types, (str, ThermoType)):
                thermo_types = [thermo_types]
            # Sort so the resulting query string is deterministic.
            query_params["thermo_types"] = ",".join(
                sorted(self._check_thermo_types(thermo_types))
            )

        if num_elements:
            if isinstance(num_elements, int):
                num_elements = (num_elements, num_elements)
            query_params["nelements_min"] = num_elements[0]
            query_params["nelements_max"] = num_elements[1]

        if is_stable is not None:
            query_params["is_stable"] = is_stable

        # Explicit mapping from API field name to the corresponding (min, max)
        # argument; replaces fragile introspection of locals().
        energy_ranges = {
            "energy_above_hull": energy_above_hull,
            "equilibrium_reaction_energy_per_atom": equilibrium_reaction_energy,
            "formation_energy_per_atom": formation_energy,
            "energy_per_atom": total_energy,
            "uncorrected_energy_per_atom": uncorrected_energy,
        }
        for field_name, bounds in energy_ranges.items():
            if bounds:
                query_params[f"{field_name}_min"] = bounds[0]
                query_params[f"{field_name}_max"] = bounds[1]

        # Drop unset bounds, e.g. a (None, 0.1) range contributes only a max.
        query_params = {
            key: value for key, value in query_params.items() if value is not None
        }

        return super()._search(  # type: ignore[return-value]
            num_chunks=num_chunks,
            chunk_size=chunk_size,
            all_fields=all_fields,
            fields=fields,
            **query_params,
        )

    def get_phase_diagram_from_chemsys(
        self, chemsys: str, thermo_type: ThermoType | str = DEFAULT_THERMOTYPE
    ) -> PhaseDiagram:
        """Get a pre-computed phase diagram for a given chemsys.

        Arguments:
            chemsys (str): A chemical system (e.g. Li-Fe-O). Every "-"-separated
                component must be a valid element symbol; component order does
                not matter.
            thermo_type (ThermoType or str): The thermo type for the phase diagram.
                Defaults to ``DEFAULT_THERMOTYPE`` from the client settings.

        Returns:
            (PhaseDiagram): Pymatgen phase diagram object, or None if no
                pre-computed phase diagram matches the query.

        Raises:
            ValueError: if ``thermo_type`` is invalid or ``chemsys`` contains a
                component that is not a valid element symbol.
        """
        validated_thermo_type = self._check_thermo_types([thermo_type]).pop()

        # Validate every component as an element symbol. Besides giving a clear
        # error for typos, this guarantees the user-supplied ``chemsys`` cannot
        # inject SQL into the query string interpolated below.
        symbols = [symbol.strip() for symbol in chemsys.split("-")]
        try:
            for symbol in symbols:
                Element(symbol)
        except (ValueError, KeyError) as exc:
            raise ValueError(f"Invalid chemical system: {chemsys!r}") from exc
        sorted_chemsys = "-".join(sorted(symbols))

        version = self.db_version.replace(".", "-")
        pd_lbl, _ = self._get_delta_table(
            "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams"
        )
        query = f"""
            SELECT phase_diagram
            FROM {pd_lbl}
            WHERE chemsys='{sorted_chemsys}'
            AND version='{version}'
            AND thermo_type='{validated_thermo_type}'
        """
        table = self._query_delta_single(query)
        as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict")
        phase_diagrams = TypeAdapter(list[PhaseDiagramType]).validate_python(as_py)
        if not phase_diagrams:
            return None  # type: ignore[return-value]
        phase_diagram = phase_diagrams[0]

        # Ensure el_ref keys are Element objects for PDPlotter, converting every
        # str key rather than stopping at the first non-str one.
        # This should be fixed in pymatgen.
        for key, entry in list(phase_diagram.el_refs.items()):
            if isinstance(key, str):
                phase_diagram.el_refs[Element(str(key))] = entry
                phase_diagram.el_refs.pop(key)

        # Ensure qhull_data is a numpy array (it may deserialize as a list).
        if isinstance(phase_diagram.qhull_data, list):
            phase_diagram.qhull_data = np.array(phase_diagram.qhull_data)

        return phase_diagram