"""Implementation of reaction network and graph classes."""
from __future__ import annotations
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING
import rustworkx as rx
from pymatgen.entries import Entry
from tqdm import tqdm
from rxn_network.costs.functions import Softplus
from rxn_network.entries.experimental import ExperimentalReferenceEntry
from rxn_network.network.base import Graph, Network
from rxn_network.network.entry import NetworkEntry, NetworkEntryType
from rxn_network.pathways.basic import BasicPathway
from rxn_network.pathways.pathway_set import PathwaySet
from rxn_network.reactions.computed import ComputedReaction
from rxn_network.reactions.open import OpenComputedReaction
from rxn_network.utils.funcs import get_logger
if TYPE_CHECKING:
from collections.abc import Iterable
from rxn_network.costs.base import CostFunction
from rxn_network.reactions.base import Reaction
from rxn_network.reactions.reaction_set import ReactionSet
logger = get_logger(__name__)
[docs]
class ReactionNetwork(Network):
    """Main reaction network class for building graph networks and performing
    pathfinding. Graphs are built using the rustworkx package (a NetworkX equivalent
    implemented in Rust).

    If you use this code in your own work, please consider citing this paper:

    McDermott, M. J.; Dwaraknath, S. S.; Persson, K. A. A Graph-Based Network for
    Predicting Chemical Reaction Pathways in Solid-State Materials Synthesis. Nature
    Communications 2021, 12 (1), 3097. https://doi.org/10.1038/s41467-021-23339-x.
    """

    def __init__(
        self,
        rxns: ReactionSet,
        cost_function: CostFunction | None = None,
    ):
        """Initialize a ReactionNetwork object for a reaction set and cost function.
        To build the graph network, call the build() method in-place.

        Args:
            rxns: Set of reactions used to construct the network.
            cost_function: The function used to calculate the cost of each reaction
                edge. Defaults to a Softplus function with default settings (i.e.
                energy_per_atom only).
        """
        if cost_function is None:
            cost_function = Softplus()
        super().__init__(rxns=rxns, cost_function=cost_function)

    def build(self) -> None:
        """In-place method. Construct the reaction network graph object and store it
        under the "graph" attribute.

        WARNING: This does NOT initialize the precursors or target attributes; you must
        call set_precursors() or set_target() to do so.

        Returns:
            None
        """
        logger.info("Building graph from reactions...")

        g = Graph()
        nodes, edges = get_rxn_nodes_and_edges(self.rxns)
        # Edge tuples refer to nodes by their position in `nodes`, so the nodes must
        # be added to the (fresh) graph before the edges.
        edges.extend(get_loopback_edges(nodes))  # type: ignore
        g.add_nodes_from(nodes)
        g.add_edges_from(edges)

        logger.info(f"Built graph with {g.num_nodes()} nodes and {g.num_edges()} edges")
        self._g = g  # type: ignore

    def find_pathways(self, targets: list[Entry | str], k: int = 15) -> PathwaySet:
        """Find the k-shortest paths to a provided list of one or more targets.

        Args:
            targets: List of the formulas or entry objects of each target.
            k: Number of k-shortest paths to find for each target. Defaults to 15.

        Returns:
            A PathwaySet built (via PathwaySet.from_paths) from the BasicPathway
            objects found for all provided targets.

        Raises:
            AttributeError: If no precursors have been set.
        """
        if not self.precursors:
            raise AttributeError("Must call set_precursors() before pathfinding!")

        paths = []
        for target in targets:
            # The target node is swapped in-place for each requested target.
            self.set_target(target)
            print(
                f"Paths to {self.target.composition.reduced_formula} \n"  # type: ignore
            )
            print("--------------------------------------- \n")
            paths.extend(self._k_shortest_paths(k=k))

        return PathwaySet.from_paths(paths)

    def set_precursors(self, precursors: Iterable[Entry | str]):
        """In-place method. Sets the precursors of the network and replaces the
        previous precursors node (if any).

        If entries are provided, will use the entries to set the precursors. If strings
        are provided, will automatically find the lowest-energy entries with matching
        reduced_formula.

        Args:
            precursors: iterable of entries/formulas of precursor phases.

        Returns:
            None

        Raises:
            ValueError: If the graph has not been built, if a precursor is not part of
                the network, or if a previously set precursors node cannot be found.
        """
        g = self._g
        if not g:
            raise ValueError("Must call build() before setting precursors!")

        precursors = {
            p if isinstance(p, (Entry, ExperimentalReferenceEntry)) else self.entries.get_min_entry_by_formula(p)
            for p in precursors
        }

        if precursors == self.precursors:
            return  # nothing to do; graph already reflects these precursors
        if not all(p in self.entries for p in precursors):
            raise ValueError("One or more precursors are not included in network!")

        precursors_entry = NetworkEntry(precursors, NetworkEntryType.Precursors)

        # NOTE(review): only the old precursors node is removed here. "loopback_edge"
        # edges added below for a previous precursor set are not removed -- confirm
        # whether stale loopback edges are acceptable when precursors change.
        if self.precursors and not self._remove_node_of_type(g, NetworkEntryType.Precursors):
            raise ValueError("Old precursors node not found in graph!")

        precursors_node = g.add_node(precursors_entry)

        reactants_type = NetworkEntryType.Reactants.value
        products_type = NetworkEntryType.Products.value

        # Fetch node payloads once instead of once per (product, node) pair.
        node_data = [(idx, g.get_node_data(idx)) for idx in g.node_indices()]

        # Reactant nodes that the precursors alone cannot supply; these are the only
        # candidates for precursor-assisted loopback edges.
        unsatisfied_reactants = [
            (idx, entry.entries)
            for idx, entry in node_data
            if entry.description.value == reactants_type and not entry.entries.issubset(precursors)
        ]

        edges_to_add = []
        for idx, entry in node_data:
            entry_type = entry.description.value
            if entry_type == reactants_type:
                if entry.entries.issubset(precursors):
                    edges_to_add.append((precursors_node, idx, "precursor_edge"))
            elif entry_type == products_type:
                # Phases available after this step: the products plus the precursors.
                available = precursors.union(entry.entries)
                edges_to_add.extend(
                    (idx, reactant_idx, "loopback_edge")
                    for reactant_idx, reactant_entries in unsatisfied_reactants
                    if available.issuperset(reactant_entries)
                )

        g.add_edges_from(edges_to_add)
        self._precursors = precursors

    def set_target(self, target: Entry | str) -> None:
        """In-place method. Can only provide one target entry or formula at a time.

        If entry is provided, will use that entry to set the target. If string is
        provided, will automatically find minimum-energy entry with
        matching reduced_formula.

        Args:
            target: Entry, or string of reduced formula, of target phase.

        Returns:
            None

        Raises:
            ValueError: If the graph has not been built, if the target is not part of
                the network, or if a previously set target node cannot be found.
        """
        g = self._g
        if not g:
            raise ValueError("Must call build() before setting target!")

        target = (
            target
            if isinstance(target, (Entry, ExperimentalReferenceEntry))
            else self.entries.get_min_entry_by_formula(target)
        )

        if target == self.target:
            return  # nothing to do; graph already reflects this target
        if target not in self.entries:
            raise ValueError("Target is not included in network!")

        if self.target and not self._remove_node_of_type(g, NetworkEntryType.Target):
            raise ValueError("Old target node not found in graph!")

        target_node = g.add_node(NetworkEntry([target], NetworkEntryType.Target))

        # Connect every products node that contains the target to the target node.
        products_type = NetworkEntryType.Products.value
        edges_to_add = []
        for node in g.node_indices():
            entry = g.get_node_data(node)
            if entry.description.value == products_type and target in entry.entries:
                edges_to_add.append((node, target_node, "target_edge"))

        g.add_edges_from(edges_to_add)
        self._target = target

    @staticmethod
    def _remove_node_of_type(g, node_type: NetworkEntryType) -> bool:
        """Remove the first node in the graph whose description matches node_type.

        Args:
            g: The graph to modify in-place.
            node_type: The NetworkEntryType of the node to remove.

        Returns:
            True if a matching node was found and removed, False otherwise.
        """
        for node in g.node_indices():
            if g.get_node_data(node).description.value == node_type.value:
                g.remove_node(node)
                return True
        return False

    def _k_shortest_paths(self, k: int):
        """Wrapper for finding the k shortest paths using Yen's algorithm. Prints each
        pathway found and returns them as a list of BasicPathway objects.

        Args:
            k: Number of shortest paths to search for.

        Raises:
            ValueError: If the graph has not been built.
        """
        g = self._g
        if not g:
            raise ValueError("Must call build() before pathfinding!")

        precursors_node = g.find_node_by_weight(
            NetworkEntry(self.precursors, NetworkEntryType.Precursors)  # type: ignore
        )
        target_node = g.find_node_by_weight(NetworkEntry([self.target], NetworkEntryType.Target))

        paths = [
            self._path_from_graph(g, path, self.cost_function)
            for path in yens_ksp(g, self.cost_function, k, precursors_node, target_node)
        ]

        for path in paths:
            print(path, "\n")

        return paths

    @staticmethod
    def _path_from_graph(g, path, cf: CostFunction):
        """Gets a BasicPathway object from a shortest path found in the network.

        Args:
            g: The graph the path was found in.
            path: Sequence of node indices describing the path.
            cf: Cost function used to evaluate each reaction edge.

        Returns:
            A BasicPathway holding the edge payload (and its cost) of every edge that
            leads into a products node.
        """
        products_type = NetworkEntryType.Products.value

        rxns = []
        costs = []
        # Walk consecutive node pairs so the first node never looks "behind" itself
        # (indexing path[step - 1] at step 0 would wrap around to the last node).
        for prev_node, node in zip(path, path[1:]):
            if g.get_node_data(node).description.value == products_type:
                e = g.get_edge_data(prev_node, node)
                rxns.append(e)
                costs.append(get_edge_weight(e, cf))

        return BasicPathway(reactions=rxns, costs=costs)
[docs]
def get_rxn_nodes_and_edges(
    rxns: ReactionSet,
) -> tuple[list[NetworkEntry], list[tuple[int, int, Reaction]]]:
    """Given a reaction set, return a list of nodes and edges for constructing the
    reaction network.

    Each reaction contributes one reactants node and one products node (nodes that
    compare equal are shared between reactions) plus one edge between them.

    Args:
        rxns: the set of enumerated reactions to build a network from.

    Returns:
        A tuple consisting of (nodes, edges) where nodes is a list of NetworkEntry
        objects and edges is a list of tuples of the form (source_idx, target_idx,
        reaction). The indices are positions in the returned nodes list.
    """
    nodes: list[NetworkEntry] = []
    edges: list[tuple[int, int, Reaction]] = []

    # Maps each distinct node to its position in `nodes`, replacing the linear
    # `in` / `.index()` scans over the node list for every reaction.
    # NOTE(review): relies on NetworkEntry defining __hash__ consistently with
    # __eq__ -- confirm against rxn_network.network.entry.
    index_of: dict[NetworkEntry, int] = {}

    def node_index(node: NetworkEntry) -> int:
        """Return the index of node, registering it first if it is new."""
        idx = index_of.get(node)
        if idx is None:
            idx = len(nodes)
            index_of[node] = idx
            nodes.append(node)
        return idx

    for rxn in tqdm(rxns):
        reactant_idx = node_index(NetworkEntry(rxn.reactant_entries, NetworkEntryType.Reactants))
        product_idx = node_index(NetworkEntry(rxn.product_entries, NetworkEntryType.Products))
        edges.append((reactant_idx, product_idx, rxn))

    return nodes, edges
[docs]
def get_loopback_edges(
    nodes: list[NetworkEntry],
) -> list[tuple[int, int, str]]:
    """Given a list of nodes to check, this function finds and returns loopback
    edges (i.e., edges that connect a product node to its equivalent reactant node).

    A product node and a reactant node are considered equivalent when they hold the
    same set of entries.

    Args:
        nodes: List of vertices from which to find loopback edges.

    Returns:
        A list of tuples of the form (source_idx, target_idx, "loopback_edge"), where
        the indices are positions in the provided nodes list. Edges are ordered by
        product index, then by reactant index.
    """
    if not nodes:
        return []

    products_type = NetworkEntryType.Products.value
    reactants_type = NetworkEntryType.Reactants.value

    # Group reactant node indices by their entry set so each product node needs a
    # single lookup instead of a scan over every node.
    reactants_by_entries: dict[frozenset, list[int]] = {}
    for idx, node in enumerate(nodes):
        if node.description.value == reactants_type:
            reactants_by_entries.setdefault(frozenset(node.entries), []).append(idx)

    edges = []
    for idx, node in enumerate(nodes):
        if node.description.value != products_type:
            continue
        for reactant_idx in reactants_by_entries.get(frozenset(node.entries), ()):
            edges.append((idx, reactant_idx, "loopback_edge"))

    return edges
[docs]
def get_edge_weight(edge_obj: object, cf: CostFunction):
    """Compute the cost/weight of a single reaction-network edge.

    Structural edges (loopback, precursor, and target edges, which are stored as
    string labels) cost nothing. Reaction edges are scored by the cost function.

    Args:
        edge_obj: An edge in the reaction network
        cf: Cost function for evaluating edge weights

    Returns:
        0.0 for a structural edge label; otherwise the result of cf.evaluate() on
        the reaction.

    Raises:
        ValueError: If the edge is neither a known label nor a supported reaction.
    """
    zero_cost_labels = ("loopback_edge", "precursor_edge", "target_edge")

    is_label = isinstance(edge_obj, str)
    if is_label and edge_obj in zero_cost_labels:
        return 0.0

    if not isinstance(edge_obj, (ComputedReaction, OpenComputedReaction)):
        raise ValueError("Unknown edge type")

    return cf.evaluate(edge_obj)
[docs]
def yens_ksp(
    g: rx.PyGraph,
    cf: CostFunction,
    num_k: int,
    precursors_node: int,
    target_node: int,
) -> list[list[int]]:
    """Yen's Algorithm for k-shortest paths, adapted for rustworkx.

    This implementation was inspired by the igraph implementation by Antonin Lenfant.

    Reference (original Yen's KSP paper):
        Jin Y. Yen, "Finding the K Shortest Loopless Paths in a Network", Management
        Science, Vol. 17, No. 11, Theory Series (Jul., 1971), pp. 712-716.

    Args:
        g: the rustworkx graph object. It is copied; the caller's graph is not
            modified.
        cf: A cost function for evaluating the edge weights.
        num_k: number of k shortest paths that should be found.
        precursors_node: the index of the node representing the precursors.
        target_node: the index of the node representing the targets.

    Returns:
        List of lists of graph vertices corresponding to each shortest path, in the
        order they were accepted. Empty if num_k < 1 or if the target cannot be
        reached from the precursors node.
    """
    if num_k < 1:
        return []

    # Work on a copy: edges are temporarily removed and re-added during the search.
    g = g.copy()

    def edge_cost(edge_obj):
        """Edge weight with the user-specified cost function bound in."""
        return get_edge_weight(edge_obj, cf)

    def path_cost(nodes):
        """Total cost of a path given as a list of node indices."""
        return sum(edge_cost(g.get_edge_data(u, v)) for u, v in zip(nodes, nodes[1:]))

    shortest = rx.dijkstra_shortest_paths(  # type: ignore
        g, precursors_node, target_node, weight_fn=edge_cost
    )
    if not shortest:
        return []

    a = [list(shortest[target_node])]  # accepted paths, in order of acceptance
    b = PriorityQueue()  # type: ignore  # candidate (cost, path) pairs

    for k in range(1, num_k):
        if len(a) < k:
            # The previous round produced no new path, so the candidates are exhausted.
            logger.info(f"Identified only k={k-1} paths before exiting. \n")
            break
        prev_path = a[k - 1]

        for i in range(len(prev_path) - 1):
            spur_node = prev_path[i]
            root_path = prev_path[:i]  # nodes before the spur node
            root_with_spur = prev_path[: i + 1]

            # Remove the outgoing edge used at the spur node by every accepted path
            # that shares the root path *including* the spur node. Comparing only the
            # nodes before the spur node would also strip edges from paths that have
            # already diverged, hiding valid spur paths.
            removed_edges = []
            for known_path in a:
                if len(known_path) - 1 > i and known_path[: i + 1] == root_with_spur:
                    u, v = known_path[i], known_path[i + 1]
                    try:
                        e = g.get_edge_data(u, v)
                    except rx.NoEdgeBetweenNodes:
                        continue  # already removed via another path
                    g.remove_edge(u, v)
                    removed_edges.append((u, v, e))

            # NOTE(review): root-path nodes are not removed from the graph, so a spur
            # path may revisit them (paths are not guaranteed loopless) -- confirm
            # whether this is intended.
            spur_path = rx.dijkstra_shortest_paths(  # type: ignore
                g, spur_node, target_node, weight_fn=edge_cost
            )

            # Restore the graph before costing the candidate path.
            g.add_edges_from(removed_edges)

            if spur_path:
                total_path = list(root_path) + list(spur_path[target_node])
                b.put((path_cost(total_path), total_path))

        # Accept the cheapest candidate that has not been accepted yet.
        while True:
            try:
                _, candidate = b.get(block=False)
            except Empty:
                break
            if candidate not in a:
                a.append(candidate)
                break

    return a