"""Implements a reaction pathway solver class which efficiently solves mass balance
equations using matrix operations.
"""
from __future__ import annotations
from abc import ABCMeta
from copy import deepcopy
from itertools import combinations
from typing import TYPE_CHECKING
import numpy as np
import ray
from monty.json import MSONable
from numba import jit
from pymatgen.core.composition import Element
from tqdm import tqdm
from rxn_network.core import Composition
from rxn_network.entries.entry_set import GibbsEntrySet
from rxn_network.enumerators.basic import BasicEnumerator, BasicOpenEnumerator
from rxn_network.enumerators.minimize import MinimizeGibbsEnumerator, MinimizeGrandPotentialEnumerator
from rxn_network.pathways.balanced import BalancedPathway
from rxn_network.pathways.pathway_set import PathwaySet
from rxn_network.reactions.computed import ComputedReaction
from rxn_network.reactions.open import OpenComputedReaction
from rxn_network.reactions.reaction_set import ReactionSet
from rxn_network.utils.funcs import get_logger, grouper
from rxn_network.utils.ray import initialize_ray, to_iterator
if TYPE_CHECKING:
from rxn_network.costs.base import CostFunction
from rxn_network.pathways.base import Pathway
from rxn_network.reactions.base import Reaction
# Module-level logger used by the solver code in this module.
logger = get_logger(__name__)
class Solver(MSONable, metaclass=ABCMeta):
    """Base definition for a pathway solver class.

    On construction, the reactions appearing in the supplied pathways are collected
    into a de-duplicated list (first occurrence wins), together with the cost that
    was recorded for each reaction at that first occurrence.
    """

    def __init__(self, pathways: PathwaySet):
        """
        Args:
            pathways: A PathwaySet object containing the pathways to be combined/balanced.
        """
        self._pathways = pathways

        rxns = []
        costs = []
        for path in self._pathways.paths:
            for rxn, cost in zip(path.reactions, path.costs):
                # The same reaction can occur in several pathways; keep only its
                # first occurrence so reactions and costs stay index-aligned.
                if rxn not in rxns:
                    rxns.append(rxn)
                    costs.append(cost)

        self._reactions = rxns
        self._costs = costs

    @property
    def pathways(self) -> PathwaySet:
        """Pathways used in solver class (the PathwaySet passed at construction)."""
        return self._pathways

    @property
    def reactions(self) -> list[Reaction]:
        """Unique reactions collected from the pathways."""
        return self._reactions

    @property
    def costs(self) -> list[float]:
        """Costs of the unique reactions, aligned index-for-index with ``reactions``."""
        return self._costs

    @property
    def num_rxns(self) -> int:
        """Length of the reaction list."""
        return len(self.reactions)

    @property
    def num_entries(self) -> int:
        """Length of entry list.

        NOTE: this base class never assigns ``_entries``; subclasses must set it
        before this property is accessed.
        """
        return len(self._entries)
class PathwaySolver(Solver):
    """Solver that implements an efficient method (using numba) for finding balanced
    reaction pathways from a list of graph-derived reaction pathways (i.e. a list of
    lists of reactions).

    If you use this code in your own work, please consider citing this paper:

        McDermott, M. J.; Dwaraknath, S. S.; Persson, K. A. A Graph-Based Network for
        Predicting Chemical Reaction Pathways in Solid-State Materials Synthesis. Nature
        Communications 2021, 12 (1), 3097. https://doi.org/10.1038/s41467-021-23339-x.
    """

    def __init__(
        self,
        pathways: PathwaySet,
        entries: GibbsEntrySet,
        cost_function: CostFunction,
        open_elem: str | Element | None = None,
        chempot: float = 0.0,
        chunk_size: int = 100000,
        batch_size: int | None = None,
    ):
        """
        Args:
            pathways: PathwaySet of reaction pathways derived from the network. A deep
                copy is stored, so the caller's object is not modified.
            entries: GibbsEntrySet containing all entries in the network.
            cost_function: CostFunction object to use for the solver.
            open_elem: Optional element to use for pathways with an open element.
            chempot: Chemical potential to use for pathways with an open element.
                Defaults to 0.0.
            chunk_size: The number of pathways per chunk to use for balancing. Defaults
                to 100,000.
            batch_size: Number of balancing jobs submitted to ray before waiting for
                them to finish. If not set (or 0), defaults to the number of cluster
                CPUs minus one, and never less than 1.
        """
        super().__init__(pathways=deepcopy(pathways))
        self._entries = entries
        self.cost_function = cost_function
        self.open_elem = Element(open_elem) if open_elem else None
        self.chempot = chempot
        self.chunk_size = chunk_size
        self.batch_size = batch_size

    def solve(
        self,
        net_rxn: ComputedReaction | OpenComputedReaction,
        max_num_combos: int = 4,
        find_intermediate_rxns: bool = True,
        intermediate_rxn_energy_cutoff: float = 0.0,
        use_basic_enumerator: bool = True,
        use_minimize_enumerator: bool = False,
        filter_interdependent: bool = True,
    ) -> PathwaySet:
        """Find combinations of candidate reactions that balance to ``net_rxn``.

        Args:
            net_rxn: The reaction representing the total reaction from precursors to
                final targets.
            max_num_combos: The maximum allowable size of the balanced reaction pathway.
                At values >=5, the solver will start to take a significant amount of
                time to run.
            find_intermediate_rxns: Whether to find intermediate reactions; crucial for
                finding pathways where intermediates react together, as these reactions
                may not occur in the graph-derived pathways. Defaults to True.
            intermediate_rxn_energy_cutoff: An energy cutoff by which to filter down
                intermediate reactions. This can be useful when there are a large number
                of possible intermediates. < 0 means allow only exergonic reactions.
            use_basic_enumerator: Whether to use the BasicEnumerator to find
                intermediate reactions. Defaults to True.
            use_minimize_enumerator: Whether to use the MinimizeGibbsEnumerator to find
                intermediate reactions. Defaults to False.
            filter_interdependent: Whether or not to filter out pathways where reaction
                steps are interdependent. Defaults to True.

        Returns:
            A PathwaySet built from the unique BalancedPathway objects found, sorted by
            average cost. Empty if no candidate reactions or combinations balance.

        Raises:
            ValueError: If ``net_rxn`` is not balanced.
        """
        if not net_rxn.balanced:
            raise ValueError("Net reaction must be balanceable to find all reaction pathways.")

        if not ray.is_initialized():
            initialize_ray()

        entries = deepcopy(self.entries).entries_list
        num_entries = len(entries)

        reactions = deepcopy(self.reactions)
        costs = deepcopy(self.costs)

        precursors = deepcopy(net_rxn.reactant_entries)
        targets = deepcopy(net_rxn.product_entries)

        logger.info(f"Net reaction: {net_rxn} \n")

        if find_intermediate_rxns:
            logger.info("Identifying reactions between intermediates...")
            intermediate_rxns = self._find_intermediate_rxns(
                targets,
                intermediate_rxn_energy_cutoff,
                use_basic_enumerator,
                use_minimize_enumerator,
            )
            # Materialize once so each reaction is paired with the cost computed for it.
            intermediate_list = list(intermediate_rxns.get_rxns())
            intermediate_costs = [self.cost_function.evaluate(r) for r in intermediate_list]
            for r, c in zip(intermediate_list, intermediate_costs):
                if r not in reactions:
                    reactions.append(r)
                    costs.append(c)

        clean_r_set = ReactionSet.from_rxns(reactions, filter_duplicates=True)
        cleaned = [(r, c) for r, c in zip(reactions, costs) if r in clean_r_set and r != net_rxn]
        if not cleaned:
            logger.info("No candidate reactions remain after cleaning; no pathways found.")
            return PathwaySet.from_paths([])
        cleaned_reactions, cleaned_costs = zip(*cleaned)

        net_rxn_vector = net_rxn.get_entry_idx_vector(num_entries)

        comp_matrices = self._build_comp_matrices(cleaned_reactions, net_rxn_vector, num_entries, max_num_combos)
        c_mats, m_mats = self._balance_comp_matrices(comp_matrices, net_rxn_vector)
        paths = self._build_pathways(c_mats, m_mats, entries, cleaned_reactions, cleaned_costs)

        if filter_interdependent:
            precursor_comps = [entry.composition for entry in precursors]
            paths = [p for p in paths if not p.contains_interdependent_rxns(precursor_comps)]

        filtered_paths = sorted(set(paths), key=lambda p: p.average_cost)
        return PathwaySet.from_paths(filtered_paths)

    def _build_comp_matrices(
        self,
        reactions,
        net_rxn_vector: np.ndarray,
        num_entries: int,
        max_num_combos: int,
    ) -> dict[int, np.ndarray]:
        """Build stacked composition matrices for every combination of 1..max_num_combos
        reactions, using parallel ray tasks.

        Returns:
            Mapping of combination size to the concatenated composition matrices that
            survived filtering. Sizes with no combinations (size larger than the number
            of reactions) or only all-zero matrices are omitted.
        """
        num_rxns = len(reactions)
        net_coeff_filter_ref = ray.put(np.argwhere(net_rxn_vector != 0).flatten())
        reactions_ref = ray.put(reactions)

        # Submit every task for every size up front, then collect.
        refs_by_size: dict[int, list] = {}
        for n in range(1, max_num_combos + 1):
            refs_by_size[n] = [
                _create_comp_matrices.remote(group, reactions_ref, num_entries, net_coeff_filter_ref)
                for group in grouper(combinations(range(num_rxns), n), self.chunk_size)
            ]

        logger.info("Building comp matrices...")
        comp_matrices: dict[int, np.ndarray] = {}
        num_objs = sum(len(refs) for refs in refs_by_size.values())
        with tqdm(total=num_objs) as pbar:
            for n, refs in refs_by_size.items():
                results = []
                for result in to_iterator(refs):
                    pbar.update(1)
                    results.append(result)
                if not results:  # no combinations of this size exist
                    continue
                mats = np.concatenate(results)
                if mats.any():
                    comp_matrices[n] = mats
        logger.info("Comp matrices done...")
        return comp_matrices

    def _balance_comp_matrices(self, comp_matrices: dict[int, np.ndarray], net_rxn_vector: np.ndarray):
        """Balance composition matrices in ray jobs that are submitted in batches.

        Returns:
            Tuple ``(c_mats, m_mats)`` of aligned lists: the composition matrix and the
            multiplicity vector of every combination that balances the net reaction.
        """
        num_cpus = int(ray.cluster_resources()["CPU"])
        # Leave one CPU free by default, but never drop to 0: a single-CPU cluster
        # would otherwise give batch_size == 0 and a division by zero below.
        batch_size = self.batch_size or max(num_cpus - 1, 1)

        num_jobs = sum(len(mats) // self.chunk_size + 1 for mats in comp_matrices.values())
        num_batches = num_jobs // batch_size + 1

        results = []
        pending = []
        batch_count = 1
        for n, comp_matrix in comp_matrices.items():
            # Large combination sizes yield many matrices, so split them into roughly
            # chunk_size pieces; small sizes are submitted as a single job.
            if n >= 4:
                splits = np.array_split(comp_matrix, len(comp_matrix) // self.chunk_size + 1)
            else:
                splits = [comp_matrix]

            for group in splits:
                if len(group) == 0:  # array_split can yield empty pieces
                    continue
                pending.append(_balance_path_arrays_cpu_wrapper.remote(group, net_rxn_vector))
                if len(pending) >= batch_size:
                    results.extend(self._collect_batch(pending, batch_count, num_batches))
                    batch_count += 1
                    pending = []

        results.extend(self._collect_batch(pending, batch_count, num_batches))

        c_mats = [mat for c_group, _ in results for mat in c_group if mat is not None]
        m_mats = [mat for _, m_group in results for mat in m_group if mat is not None]
        return c_mats, m_mats

    def _collect_batch(self, refs: list, batch_count: int, num_batches: int) -> list:
        """Wait for a batch of ray object refs (showing a progress bar) and return their results."""
        return list(
            tqdm(
                to_iterator(refs),
                total=len(refs),
                desc=f"{self.__class__.__name__} (Batch {batch_count}/{num_batches})",
            )
        )

    def _build_pathways(self, c_mats, m_mats, entries, reactions, costs) -> list[BalancedPathway]:
        """Convert balanced composition/multiplicity matrices back into BalancedPathways.

        Args:
            c_mats: Composition matrices, one row per reaction, columns indexed by entry.
            m_mats: Multiplicity vectors aligned with ``c_mats``.
            entries: Entry list indexing the matrix columns.
            reactions: Candidate reactions the matrices were built from.
            costs: Costs aligned index-for-index with ``reactions``.
        """
        paths = []
        for c_mat, m_mat in zip(c_mats, m_mats):
            path_rxns = []
            path_costs = []
            for rxn_mat in c_mat:
                ents, coeffs = zip(*[(entries[idx], c) for idx, c in enumerate(rxn_mat) if not np.isclose(c, 0.0)])
                if self.open_elem is not None:
                    rxn = OpenComputedReaction(
                        entries=ents,
                        coefficients=coeffs,
                        chempots={self.open_elem: self.chempot},
                    )
                else:
                    rxn = ComputedReaction(entries=ents, coefficients=coeffs)

                # Resolve the cost before appending anything, so path_rxns and
                # path_costs always stay the same length.
                try:
                    cost = costs[reactions.index(rxn)]
                except ValueError:
                    logger.warning(
                        "Reaction %s not found among candidate reactions; evaluating its cost directly.",
                        rxn,
                    )
                    cost = self.cost_function.evaluate(rxn)

                path_rxns.append(rxn)
                path_costs.append(cost)

            paths.append(BalancedPathway(path_rxns, m_mat.flatten(), path_costs, balanced=True))
        return paths

    def _find_intermediate_rxns(
        self,
        targets,
        energy_cutoff,
        use_basic_enumerator,
        use_minimize_enumerator,
    ):
        """Enumerate reactions among the intermediates, targets and elemental entries.

        Intermediates are all entries appearing in the solver's reactions. Targets and
        the elemental entries of ``self.entries`` are added to them.

        Args:
            targets: Target entries (the product entries of the net reaction).
            energy_cutoff: Only reactions with ``energy_per_atom`` strictly below this
                value are kept.
            use_basic_enumerator: Run BasicEnumerator (plus BasicOpenEnumerator when an
                open element is set).
            use_minimize_enumerator: Run MinimizeGibbsEnumerator (plus
                MinimizeGrandPotentialEnumerator when an open element is set).

        Returns:
            A de-duplicated ReactionSet of reactions whose entries all belong to the
            intermediate entry set.
        """
        intermediates = {e for rxn in self.reactions for e in rxn.entries}
        intermediates = GibbsEntrySet(
            list(intermediates) + targets,
        )
        target_formulas = [e.composition.reduced_formula for e in targets]
        ref_elems = {e for e in self.entries if e.is_element}
        intermediates = intermediates | ref_elems

        rxn_set = ReactionSet(
            intermediates.entries_list,
            {},
            {},
            open_elem=self.open_elem,
            chempot=self.chempot,
            all_data={},
        )

        if use_basic_enumerator:
            be = BasicEnumerator(targets=target_formulas)
            rxn_set = rxn_set.add_rxn_set(be.enumerate(intermediates))

            if self.open_elem:
                boe = BasicOpenEnumerator(
                    open_phases=[Composition(str(self.open_elem)).reduced_formula],
                    targets=target_formulas,
                )
                rxn_set = rxn_set.add_rxn_set(boe.enumerate(intermediates))

        if use_minimize_enumerator:
            mge = MinimizeGibbsEnumerator(
                targets=target_formulas,
            )
            rxn_set = rxn_set.add_rxn_set(mge.enumerate(intermediates))

            if self.open_elem:
                mgpe = MinimizeGrandPotentialEnumerator(
                    open_elem=self.open_elem,
                    mu=self.chempot,
                    targets=target_formulas,
                )
                # add_rxn_set's result must be kept (as in the calls above), otherwise
                # the grand-potential reactions are silently dropped.
                rxn_set = rxn_set.add_rxn_set(mgpe.enumerate(intermediates))

        rxns = [r for r in rxn_set if r.energy_per_atom < energy_cutoff]
        rxns = [r for r in rxns if all(e in intermediates for e in r.entries)]
        num_rxns = len(rxns)
        rxns = ReactionSet.from_rxns(rxns, filter_duplicates=True)

        logger.info(f"Found {num_rxns} intermediate reactions! \n")

        return rxns

    @property
    def entries(self) -> GibbsEntrySet:
        """Entry set used in solver."""
        return self._entries
@jit(nopython=True)
def _balance_path_arrays_cpu(
    comp_matrices: np.ndarray,
    net_coeffs: np.ndarray,
    tol: float = 1e-6,
) -> tuple[np.ndarray, np.ndarray]:
    """Fast solution for reaction multiplicities via mass balance stoichiometric
    constraints. Compiled with Numba JIT in nopython mode; the loop over combinations
    runs serially within a single call. Can be applied to large batches (100K-1M
    sets of reactions at a time).

    Args:
        comp_matrices: 3D array of stoichiometric coefficients indexed as
            [combination, reaction, entry], i.e. one (reactions x entries) matrix
            per trial combination of reactions.
        net_coeffs: 1D array containing stoichiometric coefficients of the net
            reaction, indexed by entry.
        tol: numerical tolerance for determining if a multiplicity is zero
            (i.e., if reaction was removed). Any combination with a multiplicity
            below this value is rejected.

    Returns:
        Tuple (filtered_comp_matrices, filtered_multiplicities). The first holds the
        composition matrices of only those combinations that reproduce net_coeffs
        with all multiplicities >= tol. The second holds the matching multiplicity
        vectors, with one value per reaction in the combination.
    """
    shape = comp_matrices.shape
    # Entries with a nonzero coefficient in the net reaction; each must appear in
    # at least one reaction of a combination for that combination to be viable.
    net_coeff_filter = np.argwhere(net_coeffs != 0).flatten()
    len_net_coeff_filter = len(net_coeff_filter)
    all_multiplicities = np.zeros((shape[0], shape[1]), np.float64)
    # Boolean mask marking which combinations balanced successfully.
    indices = np.full(shape[0], fill_value=False)

    for i in range(shape[0]):
        correct = True
        for j in range(len_net_coeff_filter):
            idx = net_coeff_filter[j]
            if not comp_matrices[i][:, idx].any():
                correct = False
                break
        if not correct:
            continue

        # pinv(A).T == pinv(A.T), so this is the least-squares solution m of
        # A.T @ m = net_coeffs, with A = comp_matrices[i].
        comp_pinv = np.linalg.pinv(comp_matrices[i]).T
        multiplicities = comp_pinv @ net_coeffs
        # Recompute the net coefficients implied by these multiplicities.
        solved_coeffs = comp_matrices[i].T @ multiplicities

        # Reject zero/negative (or negligible) multiplicities.
        if (multiplicities < tol).any():
            continue

        # Element-wise closeness test with atol=1e-08, rtol=1e-05 (same formula
        # as np.allclose(solved_coeffs, net_coeffs)).
        if not (np.abs(solved_coeffs - net_coeffs) <= (1e-08 + 1e-05 * np.abs(net_coeffs))).all():
            continue

        all_multiplicities[i] = multiplicities
        indices[i] = True

    # Second pass: copy the surviving combinations into compact output arrays.
    filtered_indices = np.argwhere(indices != 0).flatten()
    length = filtered_indices.shape[0]
    filtered_comp_matrices = np.empty((length, shape[1], shape[2]), np.float64)
    filtered_multiplicities = np.empty((length, shape[1]), np.float64)

    for i in range(length):
        idx = filtered_indices[i]
        filtered_comp_matrices[i] = comp_matrices[idx]
        filtered_multiplicities[i] = all_multiplicities[idx]

    return filtered_comp_matrices, filtered_multiplicities
@ray.remote
def _create_comp_matrices(combos, rxns, num_entries, net_coeff_filter):
    """Ray task: stack the entry-index vectors of each reaction combination.

    Args:
        combos: Iterable of reaction-index tuples. Falsy items (presumably padding
            values from ``grouper``) are skipped.
        rxns: Sequence of reactions, indexed by the integers in each combo.
        num_entries: Length passed to each reaction's ``get_entry_idx_vector``.
        net_coeff_filter: Entry indices with a nonzero net-reaction coefficient.

    Returns:
        The stacked matrices (one per kept combo) restricted to combinations where
        every entry in ``net_coeff_filter`` is nonzero in at least one reaction.
    """
    per_combo = []
    for combo in combos:
        if not combo:
            continue
        rows = [rxns[rxn_idx].get_entry_idx_vector(num_entries) for rxn_idx in combo]
        per_combo.append(np.vstack(rows))
    stacked = np.stack(per_combo)

    # For each combination: is every net-reaction entry covered by some reaction?
    covers_net_entries = stacked[:, :, net_coeff_filter].any(axis=1).all(axis=1)
    return stacked[covers_net_entries]
@ray.remote
def _balance_path_arrays_cpu_wrapper(
    comp_matrices,
    net_rxn_vector,
):
    """Ray task that balances one chunk of composition matrices.

    Delegates to ``_balance_path_arrays_cpu`` (using its default tolerance) so the
    work can be dispatched remotely via ``.remote(...)``.

    Returns:
        The ``(filtered_comp_matrices, filtered_multiplicities)`` tuple produced by
        ``_balance_path_arrays_cpu``.
    """
    balanced = _balance_path_arrays_cpu(comp_matrices, net_rxn_vector)
    return balanced