"""Utility functions for plotting reaction data & performing analysis."""
from __future__ import annotations
from typing import TYPE_CHECKING
import plotly.express as px
import plotly.graph_objects as go
from pymatgen.analysis.chempot_diagram import plotly_layouts
from rxn_network.costs.pareto import get_pareto_front
if TYPE_CHECKING:
from pandas import DataFrame
[docs]
def plot_reaction_scatter(
    df: DataFrame,
    x: str = "secondary_competition",
    y: str = "energy",
    z: str | None = None,
    color: str | None = None,
    plot_pareto: bool = True,
) -> go.Figure:
    """Plot a Plotly scatter plot (2D or 3D) of reactions on various thermodynamic
    metric axes. This also constructs the Pareto front on the provided dimensions.

    The input DataFrame is copied before any column is modified, and the shared
    pymatgen layout templates are deep-copied before being customized, so calling
    this function has no side effects on its arguments or on module-level state.

    Args:
        df: DataFrame with columns: rxn, energy, (primary_competition),
            (secondary_competition), (chempot_distance), (added_elems), (dE)
        x: Column name to plot on x-axis
        y: Column name to plot on y-axis
        z: Column name to plot on z-axis. If None, a 2D plot is produced.
        color: Column name to color points by. Defaults to None.
        plot_pareto: Whether to plot the Pareto front. Defaults to True.

    Returns:
        Plotly figure containing the scatter plot (and, if requested, a separate
        "Pareto front" trace).
    """
    # Local import keeps this block self-contained; only the 3D branch needs it.
    from copy import deepcopy

    def get_label_and_units(name):
        """Return the (axis label, hover units) pair for a known metric column.

        Column names that are not recognized yield two empty strings.
        """
        label, units = "", ""
        if name == "energy":
            label = r"$\mathsf{Reaction~driving~force} ~\mathrm{\left(\dfrac{\mathsf{eV}}{\mathsf{atom}}\right)}$"
            units = "eV/atom"
            if z is not None:
                # A plain-text label is used instead of the LaTeX one for 3D plots.
                label = "Reaction Driving Force"
        elif name == "chempot_distance":
            label = r"$\Sigma \Delta \mu_{\mathrm{min}} ~ \mathrm{\left(\dfrac{\mathsf{eV}}{\mathsf{atom}}\right)}$"
            if z is not None:
                label = "Total chemical potential distance"
            units = "eV/atom"
        elif name == "primary_competition":
            label = "Primary Competition"
            units = "eV/atom"
        elif name == "secondary_competition":
            label = "Secondary Competition"
            units = "eV/atom"
        elif name == "dE":
            label = "Uncertainty"
            units = "eV/atom"
        return label, units

    # Work on a copy so the caller's DataFrame is never modified.
    df = df.copy()
    df["rxn"] = df["rxn"].astype(str)
    if "added_elems" in df:
        df["has_added_elems"] = df["added_elems"] != ""

    x_label, x_units = get_label_and_units(x)
    y_label, y_units = get_label_and_units(y)
    z_label, z_units = None, None
    cols: tuple = (x, y)
    if z is not None:
        z_label, z_units = get_label_and_units(z)
        cols = (x, y, z)

    pareto_trace = None
    if plot_pareto:
        pareto_df = get_pareto_front(df, metrics=cols)
        # Pareto-optimal rows are drawn only in their own trace, not in the main one.
        df = df.loc[~df.index.isin(pareto_df.index)]
        arr = pareto_df[list(cols)].to_numpy()
        pareto_marker = {"size": 10, "color": "seagreen", "symbol": "diamond"}
        if z is None:
            pareto_trace = go.Scatter(
                x=arr[:, 0],
                y=arr[:, 1],
                hovertext=pareto_df["rxn"],
                marker=pareto_marker,
                mode="markers",
                name="Pareto front",
            )
        else:
            pareto_trace = go.Scatter3d(
                x=arr[:, 0],
                y=arr[:, 1],
                z=arr[:, 2],
                hovertext=pareto_df["rxn"],
                marker=pareto_marker,
                mode="markers",
                name="Pareto front",
            )

    if z is None:
        layout_2d = plotly_layouts["default_layout_2d"]
        fig = px.scatter(
            df,
            x=x,
            y=y,
            hover_name="rxn",
            labels={x: x_label, y: y_label},
            color=color,
            color_discrete_map={True: "darkorange", False: "lightgray"},
        )
        fig.update_layout(layout_2d)
    else:
        # Deep copies are required here: the templates live in a module-level dict
        # and contain nested dicts (e.g. "titlefont", "scene"), so a shallow copy
        # or direct assignment would leak these customizations into every other
        # user of the shared defaults.
        layout_3d = deepcopy(plotly_layouts["default_layout_3d"])
        axis_layout = deepcopy(plotly_layouts["default_3d_axis_layout"])
        axis_layout["titlefont"]["size"] = 14
        for t in ["xaxis", "yaxis", "zaxis"]:
            layout_3d["scene"][t] = axis_layout
        layout_3d["scene_camera"] = {
            "eye": {"x": -5, "y": -5, "z": 5},  # zoomed out
            "projection": {"type": "orthographic"},
            "center": {"x": -0.2, "y": -0.2, "z": -0.1},
        }
        fig = px.scatter_3d(
            df,
            x=x,
            y=y,
            z=z,
            hover_name="rxn",
            labels={x: x_label, y: y_label, z: z_label},
            template="simple_white",
            color=color,
            color_discrete_map={True: "darkorange", False: "lightgray"},
        )
        fig.update_layout(layout_3d)

    if pareto_trace is not None:
        fig.add_trace(pareto_trace)

    # One hover template is applied to every trace (main scatter and Pareto front).
    hovertemplate = (
        "<b>%{hovertext}</b><br>"
        "<br><b>"
        f"{x}"
        "</b>: %{x:.3f}"
        f" {x_units}"
        "<br><b>"
        f"{y}"
        "</b>: %{y:.3f}"
        f" {y_units}"
    )
    if z is not None:
        hovertemplate = hovertemplate + "<br><b>" + f"{z}" + "</b>: %{z:.3f}" + f" {z_units}<br>"
    fig.update_traces(hovertemplate=hovertemplate)
    return fig
[docs]
def pretty_df_layout(df: DataFrame):
"""Improve visibility for a pandas DataFrame with wide column names."""
return df.style.set_table_styles(
[
{
"selector": "th",
"props": [
("max-width", "70px"),
("text-overflow", "ellipsis"),
("overflow", "hidden"),
],
}
]
) # improve rendering in Jupyter
[docs]
def filter_df_by_precursors(df: DataFrame, precursors: list[str]):
    """Filter a reaction DataFrame by available precursors.

    A row is kept only if the reduced formula of every reactant of its reaction
    appears in ``precursors``.

    Args:
        df: DataFrame with an "rxn" column. Each entry must expose a
            ``reactants`` iterable whose items have a ``reduced_formula``
            attribute.
        precursors: Reduced formulas of the precursors that are available.

    Returns:
        A new DataFrame holding the matching rows with all original columns.
        The input DataFrame is not modified, and an empty input yields an empty
        result with the same columns.
    """
    # Set membership is O(1) and also tolerates any iterable of formulas.
    available = set(precursors)

    # Build the mask directly instead of using a scratch "precursors" column, so
    # a column of that name already present in ``df`` is left untouched.
    # astype(bool) guarantees a boolean mask even when ``df`` has no rows;
    # otherwise an empty non-boolean mask would be treated as a column selection.
    mask = (
        df["rxn"]
        .map(lambda rxn: all(r.reduced_formula in available for r in rxn.reactants))
        .astype(bool)
    )
    # Copy so the result is independent of the caller's DataFrame.
    return df.loc[mask].copy()