# Source code for rxn_network.network.base

"""Basic interface for a reaction network and its graph."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING

from monty.json import MontyDecoder, MSONable
from rustworkx import PyDiGraph

from rxn_network.entries.entry_set import GibbsEntrySet

if TYPE_CHECKING:
    from collections.abc import Iterable

    from pymatgen.entries import Entry

    from rxn_network.costs.base import CostFunction
    from rxn_network.pathways.base import Pathway
    from rxn_network.reactions.reaction_set import ReactionSet


class Network(MSONable, metaclass=ABCMeta):
    """Base definition for a reaction network."""

    def __init__(
        self,
        rxns: ReactionSet,
        cost_function: CostFunction,
    ):
        """
        Args:
            rxns: A ReactionSet object containing the reactions that form the network
                edges.
            cost_function: A CostFunction object used to evaluate reaction properties
                and assign edge weights.
        """
        self.rxns = rxns
        self.cost_function = cost_function

        # Collect every entry appearing in the reactions into one entry set.
        self.entries = GibbsEntrySet(rxns.entries)
        # NOTE(review): build_indices() is defined on GibbsEntrySet; presumably it
        # assigns per-entry indices used by subclasses when creating nodes — confirm.
        self.entries.build_indices()

        # Expected to be populated in-place by the subclass implementations of
        # set_precursors(), set_target(), and build(); None until then.
        self._precursors = None
        self._target = None
        self._g = None

    @abstractmethod
    def build(self) -> None:
        """Construct the network in-place from the supplied enumerators."""

    @abstractmethod
    def find_pathways(self, target, k) -> list[Pathway]:
        """Find reaction pathways.

        Args:
            target: The target phase of the pathways (interpretation is defined by
                the subclass).
            k: Subclass-defined search parameter; presumably the number of
                pathways to return — confirm in the concrete implementation.
        """

    @abstractmethod
    def set_precursors(self, precursors: Iterable[Entry | str]) -> None:
        """Set the phases used as precursors in the network (in-place)."""

    @abstractmethod
    def set_target(self, target: Entry | str) -> None:
        """Set the phase used as a target in the network (in-place)."""

    def as_dict(self) -> dict:
        """Returns MSONable dict for serialization.

        In addition to the constructor arguments recorded by MSONable, the current
        precursors, target, and graph are stored so that from_dict() can restore
        them without calling build(). If the network has not been built yet, the
        graph is stored as None.

        See monty package for more information.
        """
        d = super().as_dict()
        d["precursors"] = list(self.precursors) if self.precursors else None
        d["target"] = self.target
        # An unbuilt network has no graph; serialize it as None instead of raising
        # AttributeError on None.as_dict().
        d["graph"] = self.graph.as_dict() if self.graph is not None else None
        return d

    @classmethod
    def from_dict(cls, d: dict) -> Network:
        """Instantiate object from MSONable dict.

        The input dict is not modified. The stored precursors, target, and graph
        are each passed through MontyDecoder, so MSONable objects are restored
        both when ``d`` came straight from as_dict() and when it was
        round-tripped through JSON (where they arrive as plain dicts).

        See monty package for more information.
        """
        # Shallow copy so the pops below do not mutate the caller's dict.
        d = dict(d)
        decoder = MontyDecoder()

        precursors = decoder.process_decoded(d.pop("precursors", None))
        target = decoder.process_decoded(d.pop("target", None))
        graph = decoder.process_decoded(d.pop("graph", None))

        network = super().from_dict(d)
        # NOTE(review): precursors are restored as a list (as stored by as_dict),
        # while the `precursors` property is annotated as a set.
        network._precursors = precursors  # pylint: disable=protected-access
        network._target = target  # pylint: disable=protected-access
        network._g = graph  # pylint: disable=protected-access
        return network

    @property
    def precursors(self) -> set[Entry] | None:
        """The phases used as precursors in the network."""
        return self._precursors

    @property
    def target(self) -> Entry | None:
        """The phase used as a target in the network."""
        return self._target

    @property
    def graph(self):
        """Returns the network's Graph object (None if not yet built)."""
        return self._g

    @property
    def chemsys(self) -> str:
        """A string representing the chemical system (elements) of the network."""
        return "-".join(sorted(self.entries.chemsys))

    def __repr__(self) -> str:
        return f"Reaction network for chemical system: {self.chemsys}, with {self.graph!s}"

    def __str__(self) -> str:
        return self.__repr__()
class Graph(PyDiGraph):
    """Thin wrapper around rustworkx.PyDiGraph to allow for serialization and
    optimized database storage.
    """

    def as_dict(self) -> dict:
        """Represents the PyDiGraph object as a serializable dictionary.

        Node and edge payloads are converted with their own ``as_dict()`` when they
        provide one and are stored as-is otherwise. The original node indices are
        stored alongside the nodes because edges reference nodes by index.

        See monty package (MSONable) for more information.
        """
        d = {"@module": self.__class__.__module__, "@class": self.__class__.__name__}
        d["nodes"] = [
            n.as_dict() if hasattr(n, "as_dict") else n  # type: ignore
            for n in self.nodes()
        ]
        d["node_indices"] = list(self.node_indices())  # type: ignore
        d["edges"] = [
            (*e, obj.as_dict() if hasattr(obj, "as_dict") else obj)  # type: ignore
            for e, obj in zip(self.edge_list(), self.edges())
        ]
        return d

    @classmethod
    def from_dict(cls, d: dict) -> Graph:
        """Instantiates a Graph object from a dictionary.

        Nodes are re-added in their stored order. The indices assigned by the new
        graph may differ from the stored ones, so every edge endpoint is remapped
        from its stored index to the newly assigned index. Node payloads are never
        used as dict keys, so they need not be hashable, and equal payloads at
        different indices are all preserved.

        See as_dict() and monty package (MSONable) for more information.
        """
        decoder = MontyDecoder()
        nodes = decoder.process_decoded(d["nodes"])
        node_indices = decoder.process_decoded(d["node_indices"])
        edges = [(e[0], e[1], decoder.process_decoded(e[2])) for e in d["edges"]]

        graph = cls()
        new_indices = graph.add_nodes_from(list(nodes))
        # Stored (old) node index -> index assigned in the freshly built graph.
        index_map = dict(zip(node_indices, new_indices))

        graph.add_edges_from(
            [(index_map[src], index_map[tgt], payload) for src, tgt, payload in edges]
        )
        return graph

    def __repr__(self) -> str:
        return f"Graph with {self.num_nodes()} nodes and {self.num_edges()} edges"