# Source code for rxn_network.network.visualize
"""Functions for visualizing/plotting reaction networks."""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.cm
import numpy as np
from rustworkx.visualization import mpl_draw
if TYPE_CHECKING:
from rustworkx import PyGraph
[docs]
def plot_network(graph: PyGraph, vertex_cmap_name: str = "jet", **kwargs):
    """Plots a reaction network using rustworkx visualization tools (i.e., mpl_draw).

    Each node is colored according to its ``chemsys`` attribute. The unique chemical
    systems are sorted and spread across the colormap ``vertex_cmap_name``, so nodes
    that share a chemical system share a color.

    Args:
        graph: a rustworkx PyGraph object whose node payloads expose a ``chemsys``
            attribute.
        vertex_cmap_name: the name of the matplotlib colormap used to color nodes by
            chemical system. Defaults to "jet".
        **kwargs: keyword arguments to pass to mpl_draw. Any key given here overrides
            the default styling set by this function (node_size=2, width=0.2,
            arrow_size=3, alpha=0.8, and the per-chemsys node_color).

    Returns:
        Whatever ``mpl_draw`` returns for the drawn copy of ``graph``.
    """
    g = graph.copy()
    node_names = [node.chemsys for node in g.nodes()]
    color_func_v = _get_cmap_string(vertex_cmap_name, domain=sorted(node_names))
    vertex_colors = [color_func_v(chemsys) for chemsys in node_names]

    # Defaults are listed first so that caller-supplied kwargs take precedence.
    # Passing them side by side (as before) raised "got multiple values for keyword
    # argument" whenever a caller tried to override e.g. node_size.
    draw_kwargs = {
        "node_size": 2,
        "width": 0.2,
        "arrow_size": 3,
        "node_color": vertex_colors,
        "alpha": 0.8,
        **kwargs,
    }
    return mpl_draw(g, **draw_kwargs)
def _get_cmap_string(palette, domain):
    """Utility function for getting a matplotlib colormap string for a given palette and
    domain.

    The unique values of ``domain`` are sorted (via ``np.unique``) and assigned
    consecutive integer indices 0..n-1. The colormap is resampled to exactly n
    lookup-table entries and indexed with those integers, so each unique value
    receives its own color.

    Args:
        palette: a colormap accepted by matplotlib (a registered colormap name such
            as "jet", a ``Colormap`` instance, or ``None`` for the default colormap).
        domain: iterable of hashable values (e.g. chemical-system strings) to color.

    Returns:
        A callable ``cmap_out(X, **kwargs)`` returning the color for value ``X``.
        Extra keyword arguments are forwarded to the colormap call. Raises
        ``KeyError`` if ``X`` is not one of the values in ``domain``.
    """
    domain_unique = np.unique(domain)
    index_of = {key: i for i, key in enumerate(domain_unique)}
    # A zero-entry lookup table is meaningless; keep at least one entry so an
    # empty domain still produces a valid (if never used) colormap.
    lut = max(len(domain_unique), 1)

    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        # matplotlib >= 3.6: ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and
        # removed in 3.9; the registry + ``Colormap.resampled`` is the replacement.
        mpl_cmap = registry.get_cmap(palette).resampled(lut)
    else:
        # Older matplotlib without the registry API.
        mpl_cmap = matplotlib.cm.get_cmap(palette, lut=lut)

    def cmap_out(X, **kwargs):
        # Integer input indexes the resampled lookup table directly.
        return mpl_cmap(index_of[X], **kwargs)

    return cmap_out