Source code for rxn_network.enumerators.utils

"""Utility functions used by the reaction enumerator classes."""

from __future__ import annotations

from typing import TYPE_CHECKING

from pymatgen.analysis.interface_reactions import GrandPotentialInterfacialReactivity, InterfacialReactivity
from pymatgen.entries.computed_entries import ComputedEntry

from rxn_network.reactions.computed import ComputedReaction
from rxn_network.reactions.open import OpenComputedReaction
from rxn_network.utils.funcs import get_logger

if TYPE_CHECKING:
    from collections.abc import Iterable

    from pymatgen.analysis.phase_diagram import GrandPotentialPhaseDiagram, PhaseDiagram
    from pymatgen.core.periodic_table import Element
    from pymatgen.entries.computed_entries import Entry

    from rxn_network.core import Composition
    from rxn_network.entries.entry_set import GibbsEntrySet
    from rxn_network.enumerators.base import Enumerator
    from rxn_network.reactions.base import Reaction

logger = get_logger(__name__)


[docs] def get_computed_rxn( rxn: Reaction, entries: GibbsEntrySet, chempots: dict[Element, float] | None = None ) -> ComputedReaction | OpenComputedReaction: """Provided with a Reaction object and a list of possible entries, this function returns a new ComputedReaction object containing a selection of those entries. Args: rxn: Reaction object entries: Iterable of entries chempots: Optional dictionary of chemical potentials (will return an OpenComputedReaction if supplied). Returns: A ComputedReaction object transformed from a normal Reaction object """ reactant_entries = [entries.get_min_entry_by_formula(r.reduced_formula) for r in rxn.reactants] product_entries = [entries.get_min_entry_by_formula(p.reduced_formula) for p in rxn.products] if chempots: rxn = OpenComputedReaction.balance(reactant_entries, product_entries, chempots) else: rxn = ComputedReaction.balance(reactant_entries, product_entries) return rxn
[docs] def react_interface( r1: Composition, r2: Composition, filtered_entries: GibbsEntrySet, pd: PhaseDiagram, grand_pd: GrandPotentialPhaseDiagram | None = None, ): """Simple API for InterfacialReactivity module from pymatgen.""" chempots = None if grand_pd: interface = GrandPotentialInterfacialReactivity( r1, r2, grand_pd, pd_non_grand=pd, norm=True, include_no_mixing_energy=True, use_hull_energy=True, ) chempots = grand_pd.chempots else: interface = InterfacialReactivity( r1, r2, pd, use_hull_energy=True, ) rxns = [] for _, _, _, rxn, _ in interface.get_kinks(): rxn = get_computed_rxn(rxn, filtered_entries, chempots) rxns.append(rxn) return rxns
[docs] def get_elems_set(entries: Iterable[Entry]) -> set[str]: """Returns chemical system as a set of element names, for set of entries. Args: entries: An iterable of entry-like objects Returns: Set of element names (strings). """ return {str(elem) for e in entries for elem in e.composition.elements}
[docs] def get_total_chemsys_str(entries: Iterable[Entry], open_elems: Iterable[Element] | None = None) -> str: """Returns chemical system string for set of entries, with optional open element. Args: entries: An iterable of entry-like objects open_elems: optional open elements to include in chemical system """ elements = {elem for entry in entries for elem in entry.composition.elements} if open_elems: elements.update(list(open_elems)) return "-".join(sorted([str(e) for e in elements]))
[docs] def group_by_chemsys(combos: Iterable[tuple[Entry, ...]], open_elems: Iterable[Element] | None = None) -> dict: """Groups entry combinations by chemical system, with optional open element. Args: combos: Iterable of entry combinations open_elems: optional open elements to include in chemical system grouping Returns: Dictionary of entry combos grouped by chemical system """ combo_dict: dict[str, list[tuple[Entry, ...]]] = {} for combo in combos: key = get_total_chemsys_str(combo, open_elems) if key in combo_dict: combo_dict[key].append(combo) else: combo_dict[key] = [combo] return combo_dict
[docs] def stabilize_entries(pd: PhaseDiagram, entries_to_adjust: Iterable[Entry], tol: float = 1e-6) -> list[Entry]: """Simple method for stabilizing entries by decreasing their energy to be on the hull. WARNING: This method is not guaranteed to work *simultaneously* for all entries due to the fact that stabilization of one entry may destabilize others. Use with caution. Args: pd: PhaseDiagram object entries_to_adjust: Iterable of entries requiring energies to be adjusted tol: Numerical tolerance to ensure that the energy of the entry is below the hull Returns: A list of new entries with energies adjusted to be on the hull """ indices = [pd.all_entries.index(entry) for entry in entries_to_adjust] new_entries = [] for _, entry in zip(indices, entries_to_adjust): e_above_hull = pd.get_e_above_hull(entry) entry_dict = entry.to_dict() entry_dict["energy"] = entry.uncorrected_energy + (e_above_hull * entry.composition.num_atoms - tol) new_entry = ComputedEntry.from_dict(entry_dict) new_entries.append(new_entry) return new_entries
[docs] def run_enumerators(enumerators: Iterable[Enumerator], entries: GibbsEntrySet): """Utility method for calling enumerate() for a list of enumerators on a particular set of entries. Reaction sets are automatically combined and duplicates are filtered. Args: enumerators: an iterable of enumerators to use for reaction enumeration entries: an entry set to provide to the enumerate() function. """ rxn_set = None for idx, enumerator in enumerate(enumerators): logger.info(f"Running {enumerator.__class__.__name__}") rxns = enumerator.enumerate(entries) logger.info(f"Adding {len(rxns)} reactions to reaction set") rxn_set = rxns if idx == 0 else rxn_set.add_rxn_set(rxns) # type: ignore logger.info("Completed reaction enumeration. Filtering duplicates...") if rxn_set is not None: rxn_set = rxn_set.filter_duplicates() logger.info("Completed duplicate filtering.") return rxn_set