from __future__ import annotations
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING
from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos
from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc
from emmet.core.mpid import AlphaID
from emmet.core.vasp.calc_types.enums import RunType
from pymatgen.analysis.magnetism.analyzer import Ordering
from pymatgen.core.periodic_table import Element
from pymatgen.electronic_structure.bandstructure import (
BandStructure,
BandStructureSymmLine,
)
from pymatgen.electronic_structure.core import OrbitalType, Spin
from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.utils import validate_ids
if TYPE_CHECKING:
from pymatgen.electronic_structure.dos import Dos
[docs]
class ElectronicStructureRester(BaseRester):
    """Rester for the ``materials/electronic_structure`` endpoint.

    Results are modelled as ``ElectronicStructureDoc`` and keyed by ``material_id``.
    """

    suffix = "materials/electronic_structure"
    document_model = ElectronicStructureDoc  # type: ignore
    primary_key = "material_id"

    def search_electronic_structure_docs(self, *args, **kwargs):  # pragma: no cover
        """Deprecated alias of ``search``.

        Emits a ``DeprecationWarning`` and forwards all arguments to ``search``.
        """
        warnings.warn(
            "MPRester.electronic_structure.search_electronic_structure_docs is deprecated. "
            "Please use MPRester.electronic_structure.search instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.search(*args, **kwargs)

    def search(
        self,
        material_ids: str | list[str] | None = None,
        band_gap: tuple[float, float] | None = None,
        chemsys: str | list[str] | None = None,
        efermi: tuple[float, float] | None = None,
        elements: list[str] | None = None,
        exclude_elements: list[str] | None = None,
        formula: str | list[str] | None = None,
        is_gap_direct: bool | None = None,
        is_metal: bool | None = None,
        magnetic_ordering: Ordering | str | None = None,
        num_elements: tuple[int, int] | int | None = None,
        num_chunks: int | None = None,
        chunk_size: int = 1000,
        all_fields: bool = True,
        fields: list[str] | None = None,
    ):
        """Query electronic structure docs using a variety of search criteria.

        Arguments:
            material_ids (str, List[str]): A single Material ID string or list of strings
                (e.g., mp-149, [mp-149, mp-13]).
            band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
            chemsys (str, List[str]): A chemical system or list of chemical systems
                (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
            efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
            elements (List[str]): A list of elements.
            exclude_elements (List[str]): A list of elements to exclude.
            formula (str, List[str]): A formula including anonymized formula
                or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
                (e.g., [Fe2O3, ABO3]).
            is_gap_direct (bool): Whether the material has a direct band gap.
            is_metal (bool): Whether the material is considered a metal.
            magnetic_ordering (Ordering or str): Magnetic ordering of the material.
                A string is converted with ``Ordering(...)``.
            num_elements (Tuple[int,int] or int): Minimum and maximum number of elements to consider.
                A single int is used as both the minimum and the maximum.
            num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
            chunk_size (int): Number of data entries per chunk.
            all_fields (bool): Whether to return all fields in the document. Defaults to True.
            fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
                Default is material_id and last_updated if all_fields is False.

        Returns:
            ([ElectronicStructureDoc]) List of electronic structure documents
        """
        query_params: dict = {}

        if material_ids:
            if isinstance(material_ids, str):
                material_ids = [material_ids]
            query_params["material_ids"] = ",".join(validate_ids(material_ids))

        if formula:
            if isinstance(formula, str):
                formula = [formula]
            query_params["formula"] = ",".join(formula)

        if chemsys:
            if isinstance(chemsys, str):
                chemsys = [chemsys]
            query_params["chemsys"] = ",".join(chemsys)

        if elements:
            query_params["elements"] = ",".join(elements)

        if exclude_elements:
            query_params["exclude_elements"] = ",".join(exclude_elements)

        if band_gap:
            query_params.update(
                {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
            )

        if efermi:
            query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})

        if magnetic_ordering:
            # Accept the plain string form as well, matching the band structure
            # search in this module; previously a str raised AttributeError here.
            ordering = (
                Ordering(magnetic_ordering)
                if isinstance(magnetic_ordering, str)
                else magnetic_ordering
            )
            query_params["magnetic_ordering"] = ordering.value

        if num_elements:
            if isinstance(num_elements, int):
                num_elements = (num_elements, num_elements)
            query_params.update(
                {"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
            )

        # Booleans are compared against None so that an explicit False is still sent.
        if is_gap_direct is not None:
            query_params["is_gap_direct"] = is_gap_direct

        if is_metal is not None:
            query_params["is_metal"] = is_metal

        # Drop entries whose value is None (e.g. an unset bound in a range tuple).
        query_params = {
            key: value for key, value in query_params.items() if value is not None
        }

        return super()._search(
            num_chunks=num_chunks,
            chunk_size=chunk_size,
            all_fields=all_fields,
            fields=fields,
            **query_params,
        )
[docs]
class BaseESPropertyRester(BaseRester):
    """Common base for resters that expose a property of electronic structure docs.

    Provides ``es_rester``, an ``ElectronicStructureRester`` built on first access
    from this instance's own connection settings and then reused.
    """

    # Cached ElectronicStructureRester; stays None until ``es_rester`` is first read.
    _es_rester: ElectronicStructureRester | None = None
    document_model = ElectronicStructureDoc

    @property
    def es_rester(self) -> ElectronicStructureRester:
        """Return the cached ``ElectronicStructureRester``, creating it if needed."""
        if self._es_rester:
            return self._es_rester

        # Mirror this rester's configuration onto the helper rester.
        rester_kwargs = {
            "api_key": self.api_key,
            "endpoint": self.base_endpoint,
            "include_user_agent": self.include_user_agent,
            "session": self.session,
            "use_document_model": self.use_document_model,
            "headers": self.headers,
            "mute_progress_bars": self.mute_progress_bars,
        }
        self._es_rester = ElectronicStructureRester(**rester_kwargs)
        return self._es_rester
[docs]
class BandStructureRester(BaseESPropertyRester):
    """Rester for band structure summaries and band structure objects.

    ``search`` queries the ``materials/electronic_structure/bandstructure`` endpoint;
    the ``get_bandstructure_from_*`` methods build pymatgen band structure objects
    from rows returned by ``_query_delta_single``.
    """

    suffix = "materials/electronic_structure/bandstructure"
    delta_backed = False

    def search_bandstructure_summary(self, *args, **kwargs):  # pragma: no cover
        """Deprecated alias of ``search``.

        Emits a ``DeprecationWarning`` and forwards all arguments to ``search``.
        """
        warnings.warn(
            "MPRester.electronic_structure_bandstructure.search_bandstructure_summary is deprecated. "
            "Please use MPRester.electronic_structure_bandstructure.search instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.search(*args, **kwargs)

    def search(
        self,
        band_gap: tuple[float, float] | None = None,
        efermi: tuple[float, float] | None = None,
        is_gap_direct: bool | None = None,
        is_metal: bool | None = None,
        magnetic_ordering: Ordering | str | None = None,
        path_type: BSPathType | str = BSPathType.setyawan_curtarolo,
        num_chunks: int | None = None,
        chunk_size: int = 1000,
        all_fields: bool = True,
        fields: list[str] | None = None,
    ):
        """Query band structure summary data in electronic structure docs using a variety of search criteria.

        Arguments:
            band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
            efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
            is_gap_direct (bool): Whether the material has a direct band gap.
            is_metal (bool): Whether the material is considered a metal.
            magnetic_ordering (Ordering or str): Magnetic ordering of the material.
            path_type (BSPathType or str): k-path selection convention for the band structure.
                A string is looked up by member name (``BSPathType[path_type]``).
            num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
            chunk_size (int): Number of data entries per chunk.
            all_fields (bool): Whether to return all fields in the document. Defaults to True.
            fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
                Default is material_id and last_updated if all_fields is False.

        Returns:
            ([ElectronicStructureDoc]) List of electronic structure documents
        """
        query_params: dict = {}

        # The path type is always part of the query; it has a non-None default.
        query_params["path_type"] = (
            BSPathType[path_type] if isinstance(path_type, str) else path_type
        ).value

        if band_gap:
            query_params.update(
                {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
            )

        if efermi:
            query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})

        if magnetic_ordering:
            ordering = (
                Ordering(magnetic_ordering)
                if isinstance(magnetic_ordering, str)
                else magnetic_ordering
            )
            query_params["magnetic_ordering"] = ordering.value

        # Booleans are compared against None so that an explicit False is still sent.
        if is_gap_direct is not None:
            query_params["is_gap_direct"] = is_gap_direct

        if is_metal is not None:
            query_params["is_metal"] = is_metal

        # Drop entries whose value is None (e.g. an unset bound in a range tuple).
        query_params = {
            key: value for key, value in query_params.items() if value is not None
        }

        return super()._search(
            num_chunks=num_chunks,
            chunk_size=chunk_size,
            all_fields=all_fields,
            fields=fields,
            **query_params,
        )

    def get_bandstructure_from_task_id(
        self,
        task_id: str,
        run_type: str | RunType | None = None,
        path_type: str | BSPathType | None = None,
        load_projections: bool = False,
    ) -> BandStructure:
        """Get the band structure pymatgen object associated with a given task ID.

        Arguments:
            task_id (str): Task ID for the band structure calculation
            run_type (str, RunType, or None): Optional run type,
                will speed up query due to delta table partitioning.
            path_type (str, BSPathType, or None) : Optional path type to
                speed up query. A string is converted with ``BSPathType(...)``.
            load_projections (bool) : Optionally load atom- and spin-projected
                bandstructure, if available.

        Returns:
            bandstructure (BandStructure): BandStructure or BandStructureSymmLine object

        Raises:
            MPRestError: If ``path_type`` is a string that is not a valid
                ``BSPathType`` value, or if no matching row is found.
        """
        bs_lbl, _ = self._get_delta_table(
            "materialsproject-parsed",
            "core/electronic-structure/bandstructures/",
            label="bandstructure",
        )

        # Only the part after the last "-" of the task ID is used for the identifier.
        identifier = str(AlphaID(task_id.split("-")[-1], padlen=8))
        filters = [f"identifier='{identifier}'"]

        rt = None
        if run_type:
            rt = RunType(run_type) if isinstance(run_type, str) else run_type
            filters.append(f"run_type='{rt.value}'")

        pt = None
        if path_type:
            # Normalise through the enum, as done for run_type, so that only a
            # known path type value is ever interpolated into the SQL text.
            try:
                pt = BSPathType(path_type) if isinstance(path_type, str) else path_type
            except ValueError as exc:
                raise MPRestError(
                    f"Unknown band structure path type: {path_type!r}"
                ) from exc
            filters.append(f"path_convention='{pt.value}'")

        where_clause = "\nAND ".join(filters)

        def select_all_from(table_label: str) -> str:
            # Build each statement from the shared filters instead of deriving
            # one query from another with str.replace on the table label.
            return f"\nSELECT *\nFROM {table_label}\nWHERE {where_clause}\n"

        rows = self._query_delta_single(select_all_from(bs_lbl)).to_pylist(
            maps_as_pydicts="strict"
        )
        if not rows:
            criteria = f"{task_id=}"
            if rt is not None:
                criteria += f", run_type={rt}"
            if pt is not None:
                criteria += f", path_type={pt}"
            raise MPRestError(f"No bandstructure data found for {criteria}")

        # Only the first matching row is used.
        bs_row = rows[0]

        if load_projections:
            proj_bs_label, _ = self._get_delta_table(
                "materialsproject-parsed",
                "core/electronic-structure/projected-bandstructures/",
                label="bandstructure_projections",
            )
            proj_rows = self._query_delta_single(
                select_all_from(proj_bs_label)
            ).to_pylist(maps_as_pydicts="strict")
            # Projections are optional: attach them only when a row exists.
            if proj_rows:
                bs_row["projections"] = proj_rows[0]

        emmet_bs = ElectronicBS(**bs_row)
        # A labels_dict selects the symmetry-line class; otherwise the plain one.
        return emmet_bs.to_pmg(
            pmg_cls=BandStructureSymmLine if emmet_bs.labels_dict else BandStructure
        )

    def get_bandstructure_from_material_id(
        self,
        material_id: str,
        path_type: str | BSPathType = BSPathType.setyawan_curtarolo,
        line_mode=True,
        load_projections: bool = False,
    ) -> BandStructure:
        """Get the band structure pymatgen object associated with a Materials Project ID.

        Arguments:
            material_id (str): Materials Project ID for a material
            path_type (BSPathType or its value as a str): k-point path selection convention
            line_mode (bool): Whether to return data for a line-mode calculation
            load_projections (bool) : Optionally load atom- and spin-projected
                bandstructure, if available.

        Returns:
            bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object

        Raises:
            MPRestError: If no electronic structure document, no band structure
                entry for the requested mode, or no band structure object is found.
        """
        pt: BSPathType = (
            BSPathType(path_type) if isinstance(path_type, str) else path_type
        )

        if line_mode:
            # Line-mode: the task ID is stored per path type under "bandstructure".
            bs_doc = self.es_rester.search(
                material_ids=material_id, fields=["bandstructure"]
            )
            if not bs_doc:
                raise MPRestError(
                    f"No electronic structure data found for material ID {material_id}."
                )

            if (_bs_data := bs_doc[0]["bandstructure"]) is None:
                raise MPRestError(
                    f"No {pt.value} band structure data found for {material_id}"
                )

            bs_data = (
                _bs_data.model_dump() if self.use_document_model else _bs_data  # type: ignore
            )
            if bs_data.get(pt.value, None) is None:
                raise MPRestError(
                    f"No {pt.value} band structure data found for {material_id}"
                )

            bs_task_id = bs_data[pt.value]["task_id"]
        else:
            # Uniform mode: the task ID is read from the "dos" entry of the document.
            bs_doc = self.es_rester.search(material_ids=material_id, fields=["dos"])
            if not bs_doc:
                raise MPRestError(
                    f"No electronic structure data found for material ID {material_id}."
                )

            if (_bs_data := bs_doc[0]["dos"]) is None:
                raise MPRestError(
                    f"No uniform band structure data found for {material_id}"
                )

            bs_data = _bs_data.model_dump() if self.use_document_model else _bs_data
            if bs_data.get("total", None) is None:
                raise MPRestError(
                    f"No uniform band structure data found for {material_id}"
                )

            bs_task_id = bs_data["task_id"]

        bs_obj = self.get_bandstructure_from_task_id(
            bs_task_id,
            path_type=pt if line_mode else BSPathType.unknown,
            load_projections=load_projections,
        )
        if bs_obj:
            return bs_obj

        raise MPRestError("No band structure object found.")
[docs]
class DosRester(BaseESPropertyRester):
    """Rester for density of states (DOS) summaries and DOS objects.

    ``search`` queries the ``materials/electronic_structure/dos`` endpoint;
    the ``get_dos_from_*`` methods build pymatgen DOS objects from rows
    returned by ``_query_delta_single``.
    """

    suffix = "materials/electronic_structure/dos"
    delta_backed = False

    def search_dos_summary(self, *args, **kwargs):  # pragma: no cover
        """Deprecated alias of ``search``.

        Emits a ``DeprecationWarning`` and forwards all arguments to ``search``.
        """
        warnings.warn(
            "MPRester.electronic_structure_dos.search_dos_summary is deprecated. "
            "Please use MPRester.electronic_structure_dos.search instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.search(*args, **kwargs)

    def search(
        self,
        band_gap: tuple[float, float] | None = None,
        efermi: tuple[float, float] | None = None,
        element: Element | str | None = None,
        magnetic_ordering: Ordering | str | None = None,
        orbital: OrbitalType | str | None = None,
        projection_type: DOSProjectionType | str = DOSProjectionType.total,
        spin: Spin | str = Spin.up,
        num_chunks: int | None = None,
        chunk_size: int = 1000,
        all_fields: bool = True,
        fields: list[str] | None = None,
    ):
        """Query density of states summary data in electronic structure docs using a variety of search criteria.

        Arguments:
            band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
            efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
            element (Element or str): Element for element-projected dos data.
            magnetic_ordering (Ordering or str): Magnetic ordering of the material.
            orbital (OrbitalType or str): Orbital for orbital-projected dos data.
            projection_type (DOSProjectionType or str): Projection type of dos data. Default is the total dos.
            spin (Spin or str): Spin channel of dos data. Non spin-polarized data is stored in Spin.up.
            num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
            chunk_size (int): Number of data entries per chunk.
            all_fields (bool): Whether to return all fields in the document. Defaults to True.
            fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
                Default is material_id and last_updated if all_fields is False.

        Returns:
            ([ElectronicStructureDoc]) List of electronic structure documents

        Raises:
            MPRestError: If an elemental projection is requested without ``element``,
                or an orbital projection is requested without ``orbital``.
        """
        query_params: dict = {}

        # String arguments are looked up by enum member name; the enum's value is sent.
        query_params["projection_type"] = (
            DOSProjectionType[projection_type]
            if isinstance(projection_type, str)
            else projection_type
        ).value
        query_params["spin"] = (Spin[spin] if isinstance(spin, str) else spin).value

        # A projected DOS query needs the thing it is projected onto.
        if (
            query_params["projection_type"] == DOSProjectionType.elemental.value
            and element is None
        ):
            raise MPRestError(
                "To query element-projected DOS, you must also specify the `element` onto which the DOS is projected."
            )

        if (
            query_params["projection_type"] == DOSProjectionType.orbital.value
            and orbital is None
        ):
            raise MPRestError(
                "To query orbital-projected DOS, you must also specify the `orbital` character onto which the DOS is projected."
            )

        if element:
            query_params["element"] = (
                Element[element] if isinstance(element, str) else element
            ).value

        if orbital:
            query_params["orbital"] = (
                OrbitalType[orbital] if isinstance(orbital, str) else orbital
            ).value

        if band_gap:
            query_params.update(
                {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
            )

        if efermi:
            query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})

        if magnetic_ordering:
            ordering = (
                Ordering[magnetic_ordering]
                if isinstance(magnetic_ordering, str)
                else magnetic_ordering
            )
            query_params["magnetic_ordering"] = ordering.value

        # Drop entries whose value is None (e.g. an unset bound in a range tuple).
        query_params = {
            key: value for key, value in query_params.items() if value is not None
        }

        return super()._search(
            num_chunks=num_chunks,
            chunk_size=chunk_size,
            all_fields=all_fields,
            fields=fields,
            **query_params,
        )

    def get_dos_from_task_id(
        self,
        task_id: str,
        run_type: str | RunType | None = None,
        load_projections: bool = False,
    ) -> Dos:
        """Get the density of states pymatgen object associated with a given calculation ID.

        Arguments:
            task_id (str): Task ID for the density of states calculation
            run_type (str, RunType, or None): Optional run type to query by.
                Will speed up query due to delta table partitioning.
            load_projections (bool) : Optionally load atom- and spin-orbital-projected
                DOS, if available.

        Returns:
            pymatgen Dos

        Raises:
            MPRestError: If no matching row is found.
        """
        dos_lbl, _ = self._get_delta_table(
            "materialsproject-parsed",
            "core/electronic-structure/total-dos/",
            label="total_dos",
        )

        # Only the part after the last "-" of the task ID is used for the identifier.
        identifier = str(AlphaID(task_id.split("-")[-1], padlen=8))
        filters = [f"identifier='{identifier}'"]

        rt = None
        if run_type:
            rt = RunType(run_type) if isinstance(run_type, str) else run_type
            filters.append(f"run_type='{rt.value}'")

        where_clause = "\nAND ".join(filters)

        def select_all_from(table_label: str) -> str:
            # Build each statement from the shared filters instead of deriving
            # one query from another with str.replace on the table label.
            return f"\nSELECT *\nFROM {table_label}\nWHERE {where_clause}\n"

        rows = self._query_delta_single(select_all_from(dos_lbl)).to_pylist(
            maps_as_pydicts="strict"
        )
        if not rows:
            criteria = f"{task_id=}"
            if rt is not None:
                criteria += f", run_type={rt}"
            raise MPRestError(f"No DOS data found for {criteria}")

        # Only the first matching row is used.
        dos_row = rows[0]

        if load_projections:
            proj_dos_label, _ = self._get_delta_table(
                "materialsproject-parsed",
                "core/electronic-structure/projected-dos/",
                label="dos_projections",
            )
            proj_rows = self._query_delta_single(
                select_all_from(proj_dos_label)
            ).to_pylist(maps_as_pydicts="strict")
            # Projections are optional: attach them only when a row exists.
            if proj_rows:
                dos_row["projected_densities"] = proj_rows[0]

        return ElectronicDos(**dos_row).to_pmg()

    def get_dos_from_material_id(
        self, material_id: str, load_projections: bool = False
    ) -> Dos:
        """Get the complete density of states pymatgen object associated with a Materials Project ID.

        Arguments:
            material_id (str): Materials Project ID for a material
            load_projections (bool) : Optionally load atom- and spin-orbital-projected
                DOS, if available.

        Returns:
            pymatgen Dos

        Raises:
            MPRestError: If no electronic structure document or no DOS entry
                is found for ``material_id``.
        """
        dos_doc = self.es_rester.search(material_ids=material_id, fields=["dos"])
        if not dos_doc:
            raise MPRestError(
                f"No electronic structure data found for material ID {material_id}."
            )

        dos_data = dos_doc[0].get("dos")
        if not dos_data:
            raise MPRestError(f"No density of states data found for {material_id}")

        # With document models enabled the entry is dumped to a dict before lookup.
        dos_summary = dos_data.model_dump() if self.use_document_model else dos_data
        dos_task_id = dos_summary["task_id"]

        return self.get_dos_from_task_id(dos_task_id, load_projections=load_projections)