Source code for rxn_network.enumerators.basic

"""This module implements two types of basic (combinatorial) reaction enumerators."""

from __future__ import annotations

from copy import deepcopy
from itertools import combinations, product
from math import comb
from typing import TYPE_CHECKING

import ray
from pymatgen.analysis.phase_diagram import GrandPotentialPhaseDiagram, PhaseDiagram
from tqdm import tqdm

from rxn_network.entries.entry_set import GibbsEntrySet
from rxn_network.entries.utils import initialize_entry
from rxn_network.enumerators.base import Enumerator
from rxn_network.enumerators.utils import group_by_chemsys
from rxn_network.reactions.computed import ComputedReaction
from rxn_network.reactions.reaction_set import ReactionSet
from rxn_network.utils.funcs import get_logger, grouper, limited_powerset
from rxn_network.utils.ray import initialize_ray, to_iterator

if TYPE_CHECKING:
    from pymatgen.entries.computed_entries import ComputedEntry

# Module-level logger; used below to report chunk-size increases and stability filtering.
logger = get_logger(__name__)


class BasicEnumerator(Enumerator):
    """Enumerator for finding all simple reactions within a set of entries, up to a
    maximum reactant/product cardinality (n); i.e., how many phases on either side of
    the reaction. This approach does not explicitly take into account thermodynamic
    stability (i.e. phase diagram). This allows for enumeration of reactions where the
    products may not be stable with respect to each other.

    If you use this code in your own work, please consider citing this paper:

        McDermott, M. J.; Dwaraknath, S. S.; Persson, K. A. A Graph-Based Network for
        Predicting Chemical Reaction Pathways in Solid-State Materials Synthesis. Nature
        Communications 2021, 12 (1), 3097. https://doi.org/10.1038/s41467-021-23339-x.
    """

    MIN_CHUNK_SIZE = 2500
    MAX_NUM_JOBS = 5000

    def __init__(
        self,
        precursors: list[str] | None = None,
        targets: list[str] | None = None,
        n: int = 2,
        exclusive_precursors: bool = True,
        exclusive_targets: bool = False,
        filter_duplicates: bool = False,
        filter_by_chemsys: str | None = None,
        chunk_size: int = MIN_CHUNK_SIZE,
        max_num_jobs: int = MAX_NUM_JOBS,
        remove_unbalanced: bool = True,
        remove_changed: bool = True,
        max_num_constraints: int = 1,
        quiet: bool = False,
    ):
        """Initialize a BasicEnumerator object.

        Args:
            precursors: Optional list of precursor formulas. The only reactions that
                will be enumerated are those featuring one or more of these
                compositions as reactants. The "exclusive_precursors" parameter
                controls whether the reactants must be drawn exclusively from these
                precursors (the default) or need only include at least one of them.
            targets: Optional list of target formulas. The only reactions that will be
                enumerated are those featuring one or more of these compositions as
                products. The "exclusive_targets" parameter controls whether the
                products must be drawn exclusively from these targets or need only
                include at least one of them (the default).
            n: Maximum reactant/product cardinality. This is the largest possible
                number of entries on either side of the reaction. Defaults to 2.
            exclusive_precursors: Whether enumerated reactions are required to have
                reactants that are a subset of the provided list of precursors. If
                True (the default), this only identifies reactions with reactants
                selected from the provided precursors.
            exclusive_targets: Whether enumerated reactions are required to have
                products that are a subset of the provided list of targets. If False
                (the default), this identifies all reactions containing at least one
                composition from the provided list of targets (and any number of
                byproducts).
            filter_duplicates: Whether to remove duplicate reactions. Defaults to
                False.
            filter_by_chemsys: An optional hyphen-separated chemical system (e.g.
                "Li-O") used to filter the enumerated reactions: only chemical systems
                that contain every element of this system are enumerated.
            chunk_size: The minimum number of reactions per chunk processed. Needs to
                be sufficiently large to make parallelization a cost-effective
                strategy. Defaults to MIN_CHUNK_SIZE.
            max_num_jobs: The upper limit for the number of jobs created. Defaults to
                MAX_NUM_JOBS.
            remove_unbalanced: Whether to remove reactions which are unbalanced; this
                is usually advisable. Defaults to True.
            remove_changed: Whether to remove reactions which can only be balanced by
                removing a reactant/product or having it change sides. This is also
                advisable for ensuring that only unique reaction sets are produced.
                Defaults to True.
            max_num_constraints: The maximum number of allowable constraints enforced
                by reaction balancing. Defaults to 1 (which is usually advisable).
            quiet: Whether to run in quiet mode (no progress bar). Defaults to False.
        """
        super().__init__(precursors=precursors, targets=targets)

        self.n = n
        self.exclusive_precursors = exclusive_precursors
        self.exclusive_targets = exclusive_targets
        self.filter_duplicates = filter_duplicates
        self.filter_by_chemsys = filter_by_chemsys
        self.chunk_size = chunk_size
        self.max_num_jobs = max_num_jobs
        self.remove_unbalanced = remove_unbalanced
        self.remove_changed = remove_changed
        self.max_num_constraints = max_num_constraints
        self.quiet = quiet

        self._stabilize = False

        # Names of set methods used to test phase sets against precursors/targets:
        # "issuperset" -> every phase must come from the given list (exclusive);
        # "intersection" -> at least one phase must come from the given list.
        self._p_set_func = "issuperset" if self.exclusive_precursors else "intersection"
        self._t_set_func = "issuperset" if self.exclusive_targets else "intersection"

        self.open_phases: list | None = None
        self._build_pd = False
        self._build_grand_pd = False

    def enumerate(
        self,
        entries: GibbsEntrySet,
    ) -> ReactionSet:
        """Calculate all possible reactions given a set of entries. If the enumerator
        was initialized with specified precursors or target, the reactions will be
        filtered by these constraints. Every enumerator follows a standard procedure:

        1) Initialize entries, i.e., ensure that precursors and target are considered
        stable entries within the entry set.

        2) Get a dictionary representing every possible "node", i.e. phase
        combination, grouped by chemical system.

        3) Filter the combos dictionary for chemical systems which are not relevant;
        i.e., don't contain elements in precursors and/or target.

        4) Iterate through each chemical system, initializing calculators, and
        computing all possible reactions for reactant/product pair and/or
        thermodynamically predicted reactions for given reactants.

        5) Add reactions to growing list, repeat Step 4 until combos dict exhausted.

        Args:
            entries: the set of all entries to enumerate from

        Returns:
            A ReactionSet built from every reaction that survived filtering.
        """
        if not ray.is_initialized():
            initialize_ray()

        entries, precursors, targets, open_entries = self._get_initialized_entries(entries)

        combos_dict = self._get_combos_dict(
            entries,
            precursors,
            targets,
            open_entries,
        )

        open_combos = self._get_open_combos(open_entries)
        if not open_combos:
            open_combos = []

        # Place shared, read-only inputs in the Ray object store once; passing these
        # refs as top-level task arguments lets Ray dereference them in the workers.
        precursors = ray.put(precursors)
        targets = ray.put(targets)
        react_function = ray.put(self._react_function)
        open_entries = ray.put(open_entries)
        p_set_func = ray.put(self._p_set_func)
        t_set_func = ray.put(self._t_set_func)
        remove_unbalanced = ray.put(self.remove_unbalanced)
        remove_changed = ray.put(self.remove_changed)
        max_num_constraints = ray.put(self.max_num_constraints)
        entries_ref = ray.put(entries)

        num_cpus = int(ray.cluster_resources()["CPU"])

        pd_dict = {}
        if self._build_pd or self._build_grand_pd:  # pre-loop for phase diagram construction
            pd_chunk_size = int(len(combos_dict) // num_cpus) + 1
            pd_dict_refs = [
                _get_entries_and_pds.remote(
                    item_chunk,
                    entries_ref,
                    self.build_pd,
                    self.build_grand_pd,
                    getattr(self, "chempots", None),
                )
                for item_chunk in grouper(combos_dict.items(), pd_chunk_size)
            ]

            for completed in tqdm(
                to_iterator(pd_dict_refs),
                total=len(pd_dict_refs),
                disable=self.quiet,
                desc=f"Building phase diagrams ({self.__class__.__name__})",
            ):
                pd_dict.update(completed)

        chunk_size = self.chunk_size
        total = sum(self._rxn_iter_length(c, open_combos) for c in combos_dict.values())
        if total / chunk_size > self.max_num_jobs:
            chunk_size = int(total // self.max_num_jobs) + 1
            logger.info(f"Increasing chunk size to {chunk_size} due to max job limit of {self.max_num_jobs}")

        # Each chunk is a list of (reactions, filtered_entries_ref, pd_ref, grand_pd_ref)
        # tuples; a chunk may span several chemical systems and is closed as soon as it
        # holds exactly `chunk_size` candidate reactions.
        to_run: list[list[tuple]] = []
        current_chunk: list[tuple] = []
        for chemsys, combos in tqdm(
            combos_dict.items(),
            disable=self.quiet,
            desc="Building chunks...",
            total=len(combos_dict),
        ):
            filtered_entries, pd, grand_pd = None, None, None
            if self._build_pd or self._build_grand_pd:
                filtered_entries, pd, grand_pd = pd_dict[chemsys]

            # Always store refs (possibly of None): `_react` calls `ray.get` on them.
            filtered_entries = ray.put(filtered_entries)
            pd = ray.put(pd)
            grand_pd = ray.put(grand_pd)

            current_chunk_length = sum(len(c[0]) for c in current_chunk)
            current_chunk.append(([], filtered_entries, pd, grand_pd))

            for r in self._get_rxn_iterable(combos, open_combos):
                if current_chunk_length == chunk_size:
                    to_run.append(current_chunk)
                    current_chunk = [([r], filtered_entries, pd, grand_pd)]
                    current_chunk_length = 1
                else:
                    current_chunk[-1][0].append(r)
                    current_chunk_length += 1

            if current_chunk_length == chunk_size:
                to_run.append(current_chunk)
                current_chunk = []

        if current_chunk:
            to_run.append(current_chunk)

        rxn_chunk_refs = [
            _react.remote(
                chunk,
                react_function,
                open_entries,
                precursors,
                targets,
                p_set_func,
                t_set_func,
                remove_unbalanced,
                remove_changed,
                max_num_constraints,
            )
            for chunk in to_run
        ]

        results: list = []
        for completed in tqdm(
            to_iterator(rxn_chunk_refs),
            total=len(to_run),
            disable=self.quiet,
            desc=f"Enumerating reactions ({self.__class__.__name__})",
        ):
            results.extend(completed)

        return ReactionSet.from_rxns(results, entries=entries, filter_duplicates=self.filter_duplicates)

    @classmethod
    def _get_num_chunks(cls, items, open_combos, chunk_size) -> int:
        """Count the chunks needed when each chemical system's candidate reactions are
        split into groups of at most ``chunk_size`` (rounded up per chemical system).

        Args:
            items: Iterable of (chemsys, combos) pairs, e.g. ``combos_dict.items()``.
            open_combos: Open-entry combinations, forwarded to ``_rxn_iter_length``.
            chunk_size: Maximum number of candidate reactions per chunk.
        """
        num_chunks = 0
        for _, combos in items:
            num_combos = cls._rxn_iter_length(combos, open_combos)
            num_chunks += num_combos // chunk_size + bool(num_combos % chunk_size)
        return num_chunks

    @staticmethod
    def _rxn_iter_length(combos, open_combos) -> int:
        """Number of candidate reactions yielded by ``_get_rxn_iterable`` for
        ``combos`` (unordered pairs of combinations).
        """
        _ = open_combos  # not used in BasicEnumerator
        return comb(len(combos), 2)

    def _get_combos_dict(self, entries, precursor_entries, target_entries, open_entries):
        """Gets all possible entry combinations up to predefined cardinality (n),
        filtered and grouped by chemical system.
        """
        precursor_elems = [[str(el) for el in e.composition.elements] for e in precursor_entries]
        target_elems = [[str(el) for el in e.composition.elements] for e in target_entries]
        all_open_elems = {el for e in open_entries for el in e.composition.elements}

        # Open entries are never enumerated directly; they are re-added by subclasses.
        entries = entries - open_entries

        combos = [set(c) for c in limited_powerset(entries, self.n)]
        combos_dict = group_by_chemsys(combos, all_open_elems)

        return self._filter_dict_by_elems(
            combos_dict,
            precursor_elems,
            target_elems,
            all_open_elems,
        )

    def _get_open_combos(  # pylint: disable=useless-return
        self, open_entries
    ) -> list[set[ComputedEntry]] | None:
        """No open entries for BasicEnumerator, returns None."""
        _ = (self, open_entries)  # unused
        return None

    @staticmethod
    def _react_function(reactants, products, **kwargs):
        """Balance ``reactants -> products`` and return [forward, reversed] reactions."""
        _ = kwargs  # unused
        forward_rxn = ComputedReaction.balance(reactants, products)
        backward_rxn = forward_rxn.reverse()
        return [forward_rxn, backward_rxn]

    @staticmethod
    def _get_rxn_iterable(combos, open_combos):
        """Get all reaction/product combinations (unordered pairs of combos)."""
        _ = open_combos  # unused
        return combinations(combos, 2)

    def _get_initialized_entries(self, entries):
        """Returns initialized entries, precursors, target, and open entries."""

        def initialize_entries_list(ents):
            return {initialize_entry(f, entries, self.stabilize) for f in ents}

        precursors, targets = set(), set()

        entries_new = GibbsEntrySet(
            deepcopy(entries),
        )

        if self.precursors:
            precursors = initialize_entries_list(self.precursors)
        if self.targets:
            targets = initialize_entries_list(self.targets)

        # Make sure the initialized precursor/target entries are the ones used for
        # their formulas, replacing the lowest-energy entry of the same formula.
        for e in precursors | targets:
            if e not in entries_new:
                try:
                    old_e = entries_new.get_min_entry_by_formula(e.composition.reduced_formula)
                    entries_new.discard(old_e)
                except KeyError:
                    pass  # no existing entry with this formula; simply add the new one

                entries_new.add(e)

        if self.stabilize:
            entries_new = entries_new.filter_by_stability(e_above_hull=0.0)
            logger.info("Filtering by stable entries!")

        entries_new.build_indices()

        open_entries = set()
        if self.open_phases:
            open_entries = {e for e in entries_new if e.composition.reduced_formula in self.open_phases}

        return entries_new, precursors, targets, open_entries

    def _filter_dict_by_elems(
        self,
        combos_dict,
        precursor_elems,
        target_elems,
        all_open_elems,
    ):
        """Filters the dictionary of combinations by elements.

        A chemical system is kept only if it contains every element of
        ``filter_by_chemsys`` (when set), has between 2 and 9 elements, and passes the
        precursor/target element checks selected by the exclusivity settings.
        """
        all_precursor_elems = {el for g in precursor_elems for el in g}
        all_target_elems = {el for g in target_elems for el in g}
        all_open_elems = {str(el) for el in all_open_elems}

        filter_elems = None
        if self.filter_by_chemsys:
            filter_elems = set(self.filter_by_chemsys.split("-"))

        # Loop-invariant bound set methods ("issuperset" or "intersection").
        precursor_check = getattr(all_precursor_elems | all_open_elems, self._p_set_func)
        target_check = getattr(all_target_elems | all_open_elems, self._t_set_func)

        filtered_dict = {}
        for chemsys, combos in combos_dict.items():
            elems = set(chemsys.split("-"))

            if filter_elems and not elems.issuperset(filter_elems):
                continue
            if len(elems) >= 10 or len(elems) == 1:  # too few or too many elements
                continue
            if precursor_elems and not precursor_check(elems):
                continue
            if target_elems and not target_check(elems):
                continue

            filtered_dict[chemsys] = combos

        return filtered_dict

    @property
    def stabilize(self) -> bool:
        """Whether or not to use only stable entries in analysis."""
        return self._stabilize

    @property
    def build_pd(self) -> bool:
        """Whether or not to build a PhaseDiagram object during reaction enumeration
        (useful for some analyses).
        """
        return self._build_pd

    @property
    def build_grand_pd(self) -> bool:
        """Whether or not to build a GrandPotentialPhaseDiagram object during
        reaction enumeration (useful for some analyses).
        """
        return self._build_grand_pd
class BasicOpenEnumerator(BasicEnumerator):
    """Enumerator for finding all simple reactions within a set of entries, up to a
    maximum reactant/product cardinality (n), with any number of open phases.

    Note: this does not return OpenComputedReaction objects (this can be calculated
    using the ReactionSet class).

    If you use this code in your own work, please consider citing this paper:

        McDermott, M. J.; Dwaraknath, S. S.; Persson, K. A. A Graph-Based Network for
        Predicting Chemical Reaction Pathways in Solid-State Materials Synthesis. Nature
        Communications 2021, 12 (1), 3097. https://doi.org/10.1038/s41467-021-23339-x.
    """

    MIN_CHUNK_SIZE = 2500
    MAX_NUM_JOBS = 5000

    def __init__(
        self,
        open_phases: list[str],
        precursors: list[str] | None = None,
        targets: list[str] | None = None,
        n: int = 2,
        exclusive_precursors: bool = True,
        exclusive_targets: bool = False,
        filter_duplicates: bool = False,
        filter_by_chemsys: str | None = None,
        chunk_size: int = MIN_CHUNK_SIZE,
        max_num_jobs: int = MAX_NUM_JOBS,
        remove_unbalanced: bool = True,
        remove_changed: bool = True,
        max_num_constraints: int = 1,
        quiet: bool = False,
    ):
        """Supplied target and calculator parameters are automatically initialized as
        objects during enumeration.

        Args:
            open_phases: List of formulas of open entries (e.g. ["O2"]).
            precursors: Optional list of precursor formulas. The only reactions that
                will be enumerated are those featuring one or more of these
                compositions as reactants. The "exclusive_precursors" parameter
                controls whether the reactants must be drawn exclusively from these
                precursors (the default) or need only include at least one of them.
            targets: Optional list of target formulas. The only reactions that will be
                enumerated are those featuring one or more of these compositions as
                products. The "exclusive_targets" parameter controls whether the
                products must be drawn exclusively from these targets or need only
                include at least one of them (the default).
            n: Maximum reactant/product cardinality. This is the largest possible
                number of entries on either side of the reaction. Defaults to 2.
            exclusive_precursors: Whether enumerated reactions are required to have
                reactants that are a subset of the provided list of precursors. If
                True (the default), this only identifies reactions with reactants
                selected from the provided precursors.
            exclusive_targets: Whether enumerated reactions are required to have
                products that are a subset of the provided list of targets. If False
                (the default), this identifies all reactions containing at least one
                composition from the provided list of targets (and any number of
                byproducts).
            filter_duplicates: Whether to remove duplicate reactions. Defaults to
                False.
            filter_by_chemsys: An optional hyphen-separated chemical system (e.g.
                "Li-O") used to filter the enumerated reactions: only chemical systems
                that contain every element of this system are enumerated.
            chunk_size: The minimum number of reactions per chunk processed. Needs to
                be sufficiently large to make parallelization a cost-effective
                strategy. Defaults to MIN_CHUNK_SIZE.
            max_num_jobs: The upper limit for the number of jobs created. Defaults to
                MAX_NUM_JOBS.
            remove_unbalanced: Whether to remove reactions which are unbalanced; this
                is usually advisable. Defaults to True.
            remove_changed: Whether to remove reactions which can only be balanced by
                removing a reactant/product or having it change sides. This is also
                advisable for ensuring that only unique reaction sets are produced.
                Defaults to True.
            max_num_constraints: The maximum number of allowable constraints enforced
                by reaction balancing. Defaults to 1 (which is usually advisable).
            quiet: Whether to run in quiet mode (no progress bar). Defaults to False.
        """
        super().__init__(
            precursors=precursors,
            targets=targets,
            n=n,
            exclusive_precursors=exclusive_precursors,
            exclusive_targets=exclusive_targets,
            filter_duplicates=filter_duplicates,
            filter_by_chemsys=filter_by_chemsys,
            chunk_size=chunk_size,
            max_num_jobs=max_num_jobs,
            remove_unbalanced=remove_unbalanced,
            remove_changed=remove_changed,
            max_num_constraints=max_num_constraints,
            quiet=quiet,
        )
        self.open_phases: list[str] = open_phases

    @staticmethod
    def _rxn_iter_length(combos, open_combos) -> int:
        """Number of candidate reactions yielded by ``_get_rxn_iterable``: every combo
        paired with every (combo + open combo) union whose parts are disjoint.
        """
        num_combos_with_open = sum(
            1 for combo in combos for open_combo in open_combos if not combo & open_combo
        )
        return len(combos) * num_combos_with_open

    def _get_open_combos(self, open_entries):
        """Get all combinations of open entries, each as a set, with sizes up to the
        number of open entries (generated by ``limited_powerset``). For a single open
        entry this is expected to be just that entry.
        """
        return [set(c) for c in limited_powerset(open_entries, len(open_entries))]

    @staticmethod
    def _get_rxn_iterable(combos, open_combos):
        """Get all reaction/product combinations.

        Each combo is paired with every union of a combo and an open combo that share
        no entries, so open phases can appear on one side of the reaction.
        """
        combos_with_open = [
            combo | open_combo for combo in combos for open_combo in open_combos if not combo & open_combo
        ]
        return product(combos, combos_with_open)
@ray.remote
def _react(
    chunk,
    react_function,
    open_entries,
    precursors,
    targets,
    p_set_func,
    t_set_func,
    remove_unbalanced,
    remove_changed,
    max_num_constraints,
):
    """This function is a wrapper for the specific react function of each enumerator.
    This wrapper contains the logic used for filtering out reactions based on the
    user-defined enumerator settings. It can also be called as a remote function using
    ray, allowing for parallel computation during reaction enumeration.

    WARNING: this function is not intended to be called directly by the user and should
    only be used by the enumerator classes.

    Args:
        chunk: List of (candidate_pairs, filtered_entries_ref, pd_ref, grand_pd_ref)
            tuples; each candidate pair is a (reactants, products) tuple.
        react_function: Callable producing candidate reactions from (reactants,
            products) plus ``filtered_entries``/``pd``/``grand_pd`` keyword arguments.
        open_entries: Set of open entries; they are ignored when checking reactants and
            products against precursors/targets.
        precursors: Set of precursor entries (may be empty).
        targets: Set of target entries (may be empty).
        p_set_func: Name of the set method used for precursor checks.
        t_set_func: Name of the set method used for target checks.
        remove_unbalanced: Whether to drop unbalanced reactions.
        remove_changed: Whether to drop reactions with ``lowest_num_errors != 0``.
        max_num_constraints: Maximum allowed ``rxn.data["num_constraints"]``.

    Returns:
        List of reactions that passed all filters.
    """

    def _always_true(_):
        return True

    # These only depend on the arguments, so build them once instead of once per pair.
    precursor_func = getattr(precursors | open_entries, p_set_func) if precursors else _always_true
    target_func = getattr(targets | open_entries, t_set_func) if targets else _always_true

    all_rxns = []

    for rxn_iterable, filtered_entries, pd, grand_pd in chunk:
        # Refs nested inside `chunk` are not auto-dereferenced by Ray; tolerate None.
        filtered_entries = ray.get(filtered_entries) if filtered_entries is not None else None
        pd = ray.get(pd) if pd is not None else None
        grand_pd = ray.get(grand_pd) if grand_pd is not None else None

        for rp in rxn_iterable:
            if not rp:
                continue

            r = set(rp[0]) if rp[0] else set()
            p = set(rp[1]) if rp[1] else set()

            all_phases = r | p

            if (r & p) or (precursors and not precursors & all_phases) or (p and targets and not targets & all_phases):
                continue

            if not (precursor_func(r) or (p and precursor_func(p))):
                continue
            if p and not (target_func(r) or target_func(p)):
                continue

            suggested_rxns = react_function(r, p, filtered_entries=filtered_entries, pd=pd, grand_pd=grand_pd)

            for rxn in suggested_rxns:
                if (
                    rxn.is_identity
                    or (remove_unbalanced and not rxn.balanced)
                    or (remove_changed and rxn.lowest_num_errors != 0)
                    or rxn.data["num_constraints"] > max_num_constraints
                ):
                    continue

                # Open entries don't count towards precursor/target membership.
                reactant_entries = set(rxn.reactant_entries) - open_entries
                product_entries = set(rxn.product_entries) - open_entries

                if precursor_func(reactant_entries) and target_func(product_entries):
                    all_rxns.append(rxn)

    return all_rxns


@ray.remote
def _get_entries_and_pds(combos_dict_chunk, entries, build_pd, build_grand_pd, chempots):
    """Build the per-chemical-system entry subsets and (grand) phase diagrams.

    Args:
        combos_dict_chunk: Iterable of (chemsys, combos) items; ``None`` items (padding
            from ``grouper``) are skipped.
        entries: Entry set supporting ``get_subset_in_chemsys``.
        build_pd: Whether to construct a PhaseDiagram per chemical system.
        build_grand_pd: Whether to construct a GrandPotentialPhaseDiagram per chemical
            system.
        chempots: Chemical potentials passed to GrandPotentialPhaseDiagram.

    Returns:
        Dict mapping chemsys -> (filtered_entries, pd, grand_pd); unbuilt items are None.
    """
    pd_dict = {}

    for item in combos_dict_chunk:
        if item is None:
            continue

        chemsys, _ = item
        elems = chemsys.split("-")

        filtered_entries = None
        pd = None
        grand_pd = None

        if build_pd or build_grand_pd:
            filtered_entries = entries.get_subset_in_chemsys(elems)

        if build_pd:
            pd = PhaseDiagram(filtered_entries)

        if build_grand_pd:
            grand_pd = GrandPotentialPhaseDiagram(filtered_entries, chempots)

        pd_dict[chemsys] = (filtered_entries, pd, grand_pd)

    return pd_dict