Source code for rgpycrumbs.eon.plt_neb

#!/usr/bin/env python3
"""Plots Nudged Elastic Band (NEB) reaction paths and landscapes.

.. versionadded:: 0.0.2

This script provides a command-line interface (CLI) to visualize data
generated from NEB calculations. It can plot:

1.  **Energy/Eigenvalue Profiles:** Shows the evolution of the energy or
    lowest eigenvalue along the reaction coordinate. It can overlay multiple
    paths (e.g., from different optimization steps) and use a
    physically-motivated Hermite spline interpolation using force data.

2.  **2D Reaction Landscapes:** Plots the path on a 2D coordinate system
    defined by the Root Mean Square Deviation (RMSD) from the reactant
    and product structures. This requires the 'ira_mod' library.
    It can also interpolate and display the 2D energy/eigenvalue surface.

The script can also render atomic structures from a .con file as insets
on the plots for key points (reactant, saddle, product).

This script follows the guidelines laid out here:
https://realpython.com/python-script-structure/
"""

# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "click",
#   "matplotlib",
#   "numpy",
#   "scipy",
#   "jax",
#   "adjustText",
#   "cmcrameri",
#   "rich",
#   "ase",
#   "polars",
#   "h5py",
#   "chemparseplot",
# ]
# ///

import logging
import sys
from pathlib import Path

import click
import matplotlib as mpl
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from adjustText import adjust_text
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ArrowStyle
from rich.logging import RichHandler

from chemparseplot.parse.eon.neb import (
    aggregate_neb_landscape_data,
    compute_profile_rmsd,
    estimate_rbf_smoothing,
    load_structures_and_calculate_additional_rmsd,
)

# --- Library Imports ---
from chemparseplot.parse.file_ import find_file_paths
from chemparseplot.parse.trajectory.hdf5 import (
    history_to_landscape_df as hdf5_history_to_landscape_df,
)
from chemparseplot.parse.trajectory.hdf5 import (
    history_to_profile_dats,
    result_to_atoms_list,
)
from chemparseplot.parse.trajectory.hdf5 import (
    result_to_profile_dat as hdf5_result_to_profile_dat,
)
from chemparseplot.parse.trajectory.neb import (
    load_trajectory,
    trajectory_to_landscape_df,
    trajectory_to_profile_dat,
)
from chemparseplot.plot.neb import (
    plot_energy_path,
    plot_landscape_path_overlay,
    plot_landscape_surface,
    plot_structure_inset,
    plot_structure_strip,
)
from chemparseplot.plot.theme import (
    apply_axis_theme,
    get_theme,
    setup_global_theme,
)

# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s - %(message)s",
    handlers=[RichHandler(rich_tracebacks=True, show_path=False, markup=True)],
)
[docs] log = logging.getLogger("rich")
# --- Constants ---
[docs] DEFAULT_INPUT_PATTERN = "neb_*.dat"
[docs] DEFAULT_PATH_PATTERN = "neb_path_*.con"
[docs] IRA_KMAX_DEFAULT = 1.8
# --- CLI --- @click.command() @click.option( "--input-dat-pattern", default=DEFAULT_INPUT_PATTERN, help="Glob pattern for input data files.", ) @click.option( "--input-path-pattern", default=DEFAULT_PATH_PATTERN, help="Glob pattern for input path files.", ) @click.option( "--con-file", type=click.Path(exists=True, dir_okay=False, path_type=Path), default=None, help="Path to .con trajectory file.", ) @click.option( "--additional-con", type=( click.Path(exists=True, dir_okay=False, path_type=Path), str, ), # Takes (Path, Label) multiple=True, default=None, help="Path(s) to additional .con file(s) and label.", ) @click.option( "--augment-dat", type=str, default=None, help="Glob pattern for extra .dat files for surface fitting.", ) @click.option( "--augment-con", type=str, default=None, help="Glob pattern for extra .con files for surface fitting.", ) @click.option( "--sp-file", type=click.Path(exists=False, dir_okay=False, path_type=Path), default=Path("sp.con"), help="Path to explicit saddle point file (eOn sp.con).", ) @click.option( "--source", type=click.Choice(["eon", "traj", "hdf5"]), default="eon", help="Data source: 'eon' for .dat/.con pairs, 'traj' for extxyz, 'hdf5' for ChemGP HDF5.", ) @click.option( "--input-h5", type=click.Path(exists=True, dir_okay=False, path_type=Path), default=None, help="Path to ChemGP NEB HDF5 file (result or history).", ) @click.option( "--input-traj", type=click.Path(exists=True, dir_okay=False, path_type=Path), default=None, help="Path to extxyz trajectory file (used with --source traj).", ) @click.option( "--plot-type", type=click.Choice(["profile", "landscape"]), default="profile", help="Type of plot to generate.", ) @click.option( "--rbf-smoothing", type=float, default=None, show_default=True, help="Smoothing term for 2D RBF.", ) @click.option( "--landscape-mode", type=click.Choice(["path", "surface"]), default="surface", help="For landscape plot: 'path' or 'surface'.", ) @click.option( "--landscape-path", type=click.Choice(["last", "all"]), default="all", help="Last uses an interpolation only on the last path, otherwise use all points.", ) @click.option( "--project-path/--no-project-path", is_flag=True, default=True, help="Project landscape coordinates into the reaction valley (s, d).", ) @click.option( "--rc-mode", type=click.Choice(["path", "rmsd", "index"]), default="path", help="Reaction coordinate for profile plot.", ) @click.option( "--plot-structures", type=click.Choice(["none", "all", "crit_points"]), default="none", help="Structures to render on the path. Requires --con-file.", ) @click.option( "--surface-type", type=click.Choice( [ "grid", "rbf", "grad_matern", "grad_imq", "grad_imq_ny", "matern", "imq", "grad_rq", "grad_se", ] ), default="rbf", help="Interpolation method for the 2D surface.", ) @click.option( "--n-inducing", type=int, default=None, help="Number of inducing points for Nystrom or RFF features. Defaults to 300 (Nystrom) or 500 (RFF).", ) @click.option( "--show-pts/--no-show-pts", default=True, help="Show all paths from the optimization on the RMSD 2D plot.", ) @click.option( "--plot-mode", type=click.Choice(["energy", "eigenvalue"]), default="energy", help="Quantity to plot.", ) @click.option( "-o", "--output-file", type=click.Path(path_type=Path), default=None, help="Output image filename.", ) @click.option( "--start", type=int, default=None, help="Start file index for profile plot." ) @click.option("--end", type=int, default=None, help="End file index for profile plot.") @click.option( "--normalize-rc", is_flag=True, default=False, help="Normalize reaction coordinate." ) @click.option("--title", default="NEB Path", help="Plot title.") @click.option("--xlabel", default=None, help="X-axis label.") @click.option("--ylabel", default=None, help="Y-axis label.") # --- Theme and Override Options --- @click.option( "--theme", default="ruhi", help="The plotting theme to use.", ) @click.option("--cmap-profile", default=None, help="Colormap for profile plot.") @click.option("--cmap-landscape", default=None, help="Colormap for landscape plot.") @click.option("--facecolor", type=str, default=None, help="Background color.") @click.option("--fontsize-base", type=int, default=None, help="Base font size.") # --- Figure and Inset Options --- @click.option( "--figsize", nargs=2, type=(float, float), default=(5.37, 5.37), show_default=True, help="Figure width, height in inches.", ) @click.option( "--fig-height", type=float, default=None, help="Figure height in inches.", ) @click.option( "--aspect-ratio", type=float, default=None, help="Figure aspect ratio.", ) @click.option( "--dpi", type=int, default=200, show_default=True, help="Resolution in Dots Per Inch.", ) @click.option( "--zoom-ratio", type=float, default=0.4, show_default=True, help="Scale the inset image.", ) @click.option( "--ase-rotation", type=str, default="0x, 90y, 0z", show_default=True, help="ASE rotation string.", ) @click.option( "--strip-renderer", type=click.Choice(["ase", "xyzrender"]), default="ase", show_default=True, help="Rendering backend for structure images.", ) @click.option( "--arrow-head-length", type=float, default=0.2, show_default=True, help="Arrow head length.", ) @click.option( "--arrow-head-width", type=float, default=0.3, show_default=True, help="Arrow head width.", ) @click.option( "--arrow-tail-width", type=float, default=0.1, show_default=True, help="Arrow tail width.", ) # --- Path/Spline Options --- @click.option( "--highlight-last/--no-highlight-last", is_flag=True, default=True, help="Highlight last path.", ) @click.option( "--spline-method", type=click.Choice(["hermite", "spline"]), default="hermite", help="Spline interpolation method.", ) # --- Inset Position Options --- @click.option( "--draw-reactant", type=(float, float, float), nargs=3, default=(15, 60, 0.1), show_default=True, help="Reactant inset pos (x, y, rad).", ) @click.option( "--draw-saddle", type=(float, float, float), nargs=3, default=(15, 60, 0.1), show_default=True, help="Saddle inset pos (x, y, rad).", ) @click.option( "--draw-product", type=(float, float, float), nargs=3, default=(15, 60, 0.1), show_default=True, help="Product inset pos (x, y, rad).", ) @click.option( "--cache-file", type=click.Path(path_type=Path), default=Path(".neb_landscape.parquet"), help="Parquet cache file.", ) @click.option( "--force-recompute", is_flag=True, default=False, help="Force re-calculation of RMSD.", ) @click.option( "--show-legend", is_flag=True, default=False, help="Show the legends.", ) @click.option(
[docs] "--ira-kmax", default=IRA_KMAX_DEFAULT, help="kmax factor for IRA.", ) def main( # --- Input Files --- input_dat_pattern, input_path_pattern, con_file, additional_con, # --- Data Source --- source, input_traj, input_h5, # --- Plot Behavior --- plot_type, landscape_mode, landscape_path, project_path, rc_mode, plot_structures, rbf_smoothing, show_pts, plot_mode, surface_type, n_inducing, # --- Output & Slicing --- output_file, start, end, # --- Plot Aesthetics --- normalize_rc, title, xlabel, ylabel, highlight_last, # --- Theme --- theme, cmap_profile, cmap_landscape, facecolor, fontsize_base, # --- Figure & Inset --- figsize, fig_height, aspect_ratio, dpi, zoom_ratio, ase_rotation, strip_renderer, arrow_head_length, arrow_head_width, arrow_tail_width, # --- Spline --- spline_method, # --- Inset Positions --- draw_reactant, draw_saddle, draw_product, show_legend, # Caching cache_file, force_recompute, ira_kmax, sp_file, augment_dat, augment_con, ): """Main entry point for NEB plot script.""" # 1. Setup Theme active_theme = get_theme( theme, cmap_profile=cmap_profile, cmap_landscape=cmap_landscape, font_size=fontsize_base, facecolor=facecolor, ) setup_global_theme(active_theme) if fig_height and aspect_ratio: figsize = (fig_height * aspect_ratio, fig_height) elif fig_height or aspect_ratio: log.error( "Both --fig-height and --aspect-ratio must be provided together. Using default figsize." ) fig = plt.figure(figsize=figsize, dpi=dpi) # Layout Logic has_strip = plot_structures in ["all", "crit_points"] and plot_type == "landscape" if has_strip: # Heuristic layout adjustment n_expected = (3 if plot_structures == "crit_points" else 10) + len( additional_con or [] ) max_cols = 6 n_rows = (n_expected + max_cols - 1) // max_cols calc_hspace = 0.8 if n_rows > 1 else 0.3 gs = GridSpec(2, 1, height_ratios=[1, 0.25], hspace=calc_hspace, figure=fig) ax = fig.add_subplot(gs[0]) ax_strip = fig.add_subplot(gs[1]) apply_axis_theme(ax_strip, active_theme) else: ax = fig.add_subplot(111) ax_strip = None apply_axis_theme(ax, active_theme) atoms_list = None additional_atoms_data = [] sp_data = None # Only attempt to load structures if specifically requested or needed for the plot type if con_file: try: atoms_list, additional_atoms_data, sp_data = ( load_structures_and_calculate_additional_rmsd( con_file, additional_con, ira_kmax, sp_file ) ) except Exception as e: log.error(f"Error loading structures: {e}") # Critical failure for landscape/RMSD modes if plot_type == "landscape" or rc_mode == "rmsd": log.critical("Cannot proceed without structures. Exiting.") sys.exit(1) # --- Trajectory source: load once if applicable --- traj_atoms_list = None if source == "traj": if not input_traj: log.critical("--input-traj is required when --source traj is used.") sys.exit(1) traj_atoms_list = load_trajectory(str(input_traj)) if plot_type == "landscape": # --- Landscape Plot --- z_label = ( "Relative Energy (eV)" if plot_mode == "energy" else r"Lowest Eigenvalue (eV/$\AA^2$)" ) if source == "traj": df = trajectory_to_landscape_df(traj_atoms_list, ira_kmax=ira_kmax) # Use traj structures for con_file features when not provided if atoms_list is None: atoms_list = traj_atoms_list elif source == "hdf5": if not input_h5: log.critical("--input-h5 is required when --source hdf5 is used.") sys.exit(1) h5_str = str(input_h5) # Prefer history file for multi-step landscape try: df = hdf5_history_to_landscape_df(h5_str, ira_kmax=ira_kmax) except Exception: log.warning("History read failed, falling back to single-step result.") from chemparseplot.parse.trajectory.hdf5 import ( result_to_atoms_list as _r2a, ) from chemparseplot.parse.trajectory.neb import ( trajectory_to_landscape_df as _traj_ldf, ) hdf5_atoms = _r2a(h5_str) df = _traj_ldf(hdf5_atoms, ira_kmax=ira_kmax) if atoms_list is None: atoms_list = result_to_atoms_list(h5_str) else: dat_paths = find_file_paths(input_dat_pattern) con_paths = find_file_paths(str(input_path_pattern)) if not dat_paths: log.critical(f"No data files found for pattern: {input_dat_pattern}") sys.exit(1) # Fallback if no path files found but main file exists if not con_paths and con_file: con_paths = [con_file] y_col = 2 if plot_mode == "energy" else 4 df = aggregate_neb_landscape_data( dat_paths, con_paths, y_col, None, cache_file=cache_file, force_recompute=force_recompute, ira_kmax=ira_kmax, augment_dat=augment_dat, augment_con=augment_con, ref_atoms=atoms_list[0] if atoms_list else None, # main reactant prod_atoms=atoms_list[-1] if atoms_list else None, # main product ) # Surface Generation if landscape_mode == "surface": if landscape_path == "last": max_step = df["step"].max() df_surface = df.filter(pl.col("step") == max_step) else: df_surface = df # Prepare arrays r_all = df_surface["r"].to_numpy() p_all = df_surface["p"].to_numpy() z_all = df_surface["z"].to_numpy() gr_all = df_surface["grad_r"].to_numpy() gp_all = df_surface["grad_p"].to_numpy() step_all = df_surface["step"].to_numpy() # Heuristic for RBF smoothing if missing if rbf_smoothing is None: rbf_smoothing = estimate_rbf_smoothing(df) log.info(f"Calculated heuristic RBF smoothing: {rbf_smoothing:.4f}") extra_pts = [] if sp_data: extra_pts.append([sp_data["r"], sp_data["p"]]) for _, add_r, add_p, _ in additional_atoms_data: extra_pts.append([add_r, add_p]) extra_pts_arr = np.array(extra_pts) if extra_pts else None plot_landscape_surface( ax, r_all, p_all, gr_all, gp_all, z_all, step_data=step_all, method=surface_type, rbf_smooth=rbf_smoothing, cmap=active_theme.cmap_landscape, show_pts=show_pts, # so we always show 5% and 95%, this is the user defined additional one # TODO(rg): just be a user parameter.. variance_threshold=0.5, # 50% uncertainty project_path=project_path, extra_points=extra_pts_arr, n_inducing=n_inducing, ) # Path Overlay (Final Step) max_step = df["step"].max() df_final = df.filter(pl.col("step") == max_step) final_r = df_final["r"].to_numpy() final_p = df_final["p"].to_numpy() final_z = df_final["z"].to_numpy() plot_landscape_path_overlay( ax, final_r, final_p, final_z, active_theme.cmap_landscape, z_label, project_path=project_path, ) # Saddle Point Marker if sp_data: # Use explicit SP coordinates sp_x_raw, sp_y_raw = sp_data["r"], sp_data["p"] log.info(f"Plotting explicit SP at R={sp_x_raw:.3f}, P={sp_y_raw:.3f}") else: # Fallback to heuristic if plot_mode == "energy": saddle_idx = np.argmax(final_z[1:-1]) + 1 else: saddle_idx = np.argmin(final_z) sp_x_raw, sp_y_raw = final_r[saddle_idx], final_p[saddle_idx] # Apply projection to saddle point if enabled if project_path: r_start, p_start = final_r[0], final_p[0] r_end, p_end = final_r[-1], final_p[-1] vec_r = r_end - r_start vec_p = p_end - p_start path_norm = np.hypot(vec_r, vec_p) u_r, u_p = vec_r / path_norm, vec_p / path_norm v_r, v_p = -u_p, u_r sp_x = (sp_x_raw - r_start) * u_r + (sp_y_raw - p_start) * u_p sp_y = (sp_x_raw - r_start) * v_r + (sp_y_raw - p_start) * v_p else: sp_x, sp_y = sp_x_raw, sp_y_raw ax.scatter( sp_x, sp_y, marker="s", s=int(active_theme.font_size * 2), c="white", edgecolors="black", linewidths=1.5, zorder=100, label="SP", ) if additional_atoms_data: marker_cmap = mpl.colormaps.get_cmap("tab10") for i, (_, add_r, add_p, add_label) in enumerate(additional_atoms_data): color = marker_cmap(i % 10) if project_path: plot_add_r = (add_r - r_start) * u_r + (add_p - p_start) * u_p plot_add_p = (add_r - r_start) * v_r + (add_p - p_start) * v_p else: plot_add_r, plot_add_p = add_r, add_p ax.plot( plot_add_r, plot_add_p, marker="*", markersize=int(active_theme.font_size * 1.1), color=color, markeredgecolor="white", markeredgewidth=1.0, linestyle="None", zorder=102, label=add_label, ) if has_strip and atoms_list: strip_payload = [] # Helper to calculate projected coordinates for labels def get_projected_coords(r_val, p_val): if project_path: s_val = (r_val - r_start) * u_r + (p_val - p_start) * u_p d_val = (r_val - r_start) * v_r + (p_val - p_start) * v_p return s_val, d_val return r_val, p_val # Add Reactant rx, ry = get_projected_coords(final_r[0], final_p[0]) strip_payload.append({"atoms": atoms_list[0], "x": rx, "y": ry, "label": "R"}) # Add Saddle (Explicit or Heuristic) if sp_data: sx, sy = get_projected_coords(sp_data["r"], sp_data["p"]) strip_payload.append( {"atoms": sp_data["atoms"], "x": sx, "y": sy, "label": "SP"} ) else: s_idx = ( (np.argmax(final_z[1:-1]) + 1) if plot_mode == "energy" else np.argmin(final_z) ) sx, sy = get_projected_coords(final_r[s_idx], final_p[s_idx]) strip_payload.append( {"atoms": atoms_list[s_idx], "x": sx, "y": sy, "label": "SP"} ) # Add Product px, py = get_projected_coords(final_r[-1], final_p[-1]) strip_payload.append( {"atoms": atoms_list[-1], "x": px, "y": py, "label": "P"} ) # Add intermediate points if 'all' requested if plot_structures == "all": for i in range(1, len(atoms_list) - 1): ix, iy = get_projected_coords(final_r[i], final_p[i]) strip_payload.append( {"atoms": atoms_list[i], "x": ix, "y": iy, "label": str(i)} ) # Add additional structures for add_atoms, add_r, add_p, add_label in additional_atoms_data: ax_r, ax_p = get_projected_coords(add_r, add_p) strip_payload.append( { "atoms": add_atoms, "x": ax_r, "y": ax_p, "label": add_label, } ) strip_payload.sort(key=lambda d: d["x"]) labels = [d["label"] for d in strip_payload] structs = [d["atoms"] for d in strip_payload] plot_structure_strip( ax_strip, structs, labels, zoom=zoom_ratio, rotation=ase_rotation, theme_color=active_theme.textcolor, renderer=strip_renderer, ) # Annotate Main Plot -- only label R, SP, P (not additional con; # those are identified by the legend markers instead) main_plot_texts = [] main_labels = {"R", "SP", "P"} for d in strip_payload: if d["label"] not in main_labels: continue t = ax.text( d["x"], d["y"], d["label"], fontsize=11, fontweight="bold", color="white", ha="center", va="bottom", zorder=102, ) t.set_path_effects( [path_effects.withStroke(linewidth=2.5, foreground="black")] ) main_plot_texts.append(t) if main_plot_texts: adjust_text( main_plot_texts, ax=ax, arrowprops={"arrowstyle": "-", "color": "white", "lw": 1.0}, expand_points=(2.0, 2.0), force_text=(1.0, 2.0), force_points=(1.0, 2.0), ) # Labels if project_path: final_xlabel = xlabel or r"Reaction progress $s$ ($\AA$)" final_ylabel = ylabel or r"Orthogonal deviation $d$ ($\AA$)" final_title = "Reaction Valley Projection" if title == "NEB Path" else title else: final_xlabel = xlabel or r"RMSD from Reactant ($\AA$)" final_ylabel = ylabel or r"RMSD from Product ($\AA$)" final_title = "RMSD(R,P) projection" if title == "NEB Path" else title else: # --- Profile Plot --- if source == "hdf5": if not input_h5: log.critical("--input-h5 is required when --source hdf5 is used.") sys.exit(1) h5_str = str(input_h5) # Use history final step if available, else result try: dats = history_to_profile_dats(h5_str) data = dats[-1] except Exception: data = hdf5_result_to_profile_dat(h5_str) if atoms_list is None: atoms_list = result_to_atoms_list(h5_str) if rc_mode == "index": data[1] = np.arange(data.shape[1]) elif normalize_rc: data[1] = data[1] / data[1].max() if data[1].max() > 0 else data[1] y_col = 2 if plot_mode == "energy" else 4 color = active_theme.highlight_color plot_energy_path( ax, data[1], data[y_col], data[3], color, 1.0, 20, method=spline_method, ) elif source == "traj": # Trajectory source: single extxyz file -> one profile data = trajectory_to_profile_dat(traj_atoms_list) if atoms_list is None: atoms_list = traj_atoms_list if rc_mode == "index": data[1] = np.arange(data.shape[1]) elif normalize_rc: data[1] = data[1] / data[1].max() if data[1].max() > 0 else data[1] y_col = 2 if plot_mode == "energy" else 4 color = active_theme.highlight_color plot_energy_path( ax, data[1], data[y_col], data[3], color, 1.0, 20, method=spline_method, ) if atoms_list and plot_structures != "none": indices = ( list(range(len(atoms_list))) if plot_structures == "all" else sorted( { 0, np.argmax(data[y_col][1:-1]) + 1 if plot_mode == "energy" else np.argmin(data[y_col]), len(atoms_list) - 1, } ) ) for i in indices: if i == 0: xybox, rad = draw_reactant[:2], draw_reactant[2] elif i == len(atoms_list) - 1: xybox, rad = draw_product[:2], draw_product[2] else: xybox, rad = draw_saddle[:2], draw_saddle[2] if plot_structures == "all": xybox = (15.0, 60.0 if i % 2 == 0 else -60.0) rad = 0.1 if i % 2 == 0 else -0.1 plot_structure_inset( ax, atoms_list[i], data[1][i], data[y_col][i], xybox, rad, zoom=zoom_ratio, rotation=ase_rotation, renderer=strip_renderer, ) else: # eOn source: multiple .dat files dat_paths = find_file_paths(input_dat_pattern) file_paths_to_plot = dat_paths[start:end] if not file_paths_to_plot: log.error("No files found in range.") sys.exit(1) # Optional: Load RMSD for X-axis rmsd_rc = None if rc_mode == "rmsd" and atoms_list: df_rmsd = compute_profile_rmsd( atoms_list, cache_file, force_recompute, ira_kmax ) rmsd_rc = df_rmsd["r"].to_numpy() # Plot Loop cm = plt.get_cmap(active_theme.cmap_profile) color_divisor = ( len(file_paths_to_plot) - 1 if len(file_paths_to_plot) > 1 else 1.0 ) y_col = 2 if plot_mode == "energy" else 4 for idx, fpath in enumerate(file_paths_to_plot): try: data = np.loadtxt(fpath, skiprows=1).T except Exception as ex: log.error(ex) continue # X-Axis Logic if rc_mode == "rmsd" and rmsd_rc is not None: if len(rmsd_rc) == data.shape[1]: data[1] = rmsd_rc elif rc_mode == "index": data[1] = np.arange(data.shape[1]) elif normalize_rc: data[1] = data[1] / data[1].max() if data[1].max() > 0 else data[1] # Style Logic is_last = idx == len(file_paths_to_plot) - 1 if highlight_last and is_last: color, alpha, zorder = active_theme.highlight_color, 1.0, 20 else: color = cm(idx / color_divisor) alpha = 1.0 if idx == 0 else 0.5 zorder = 10 if idx == 0 else 5 # Plot plot_energy_path( ax, data[1], data[y_col], data[3], # Forces color, alpha, zorder, method=spline_method, ) if ( highlight_last and is_last and atoms_list and plot_structures != "none" ): indices = ( list(range(len(atoms_list))) if plot_structures == "all" else sorted( { 0, np.argmax(data[y_col][1:-1]) + 1 if plot_mode == "energy" else np.argmin(data[y_col]), len(atoms_list) - 1, } ) ) for i in indices: if i == 0: xybox, rad = draw_reactant[:2], draw_reactant[2] elif i == len(atoms_list) - 1: xybox, rad = draw_product[:2], draw_product[2] else: xybox, rad = draw_saddle[:2], draw_saddle[2] if plot_structures == "all": xybox = (15.0, 60.0 if i % 2 == 0 else -60.0) rad = 0.1 if i % 2 == 0 else -0.1 # Call library function plot_structure_inset( ax, atoms_list[i], data[1][i], data[y_col][i], xybox, rad, zoom=zoom_ratio, rotation=ase_rotation, renderer=strip_renderer, ) # --- Profile Additional Structures --- if additional_atoms_data and rc_mode == "rmsd": for i, (add_atoms, add_r, _) in enumerate(additional_atoms_data): ax.axvline( add_r, color=active_theme.gridcolor, linestyle=":", linewidth=2, zorder=90, ) if plot_structures != "none": y_span = ax.get_ylim()[1] - ax.get_ylim()[0] y_pos = ax.get_ylim()[0] + 0.9 * y_span plot_structure_inset( ax, add_atoms, add_r, y_pos, xybox=( draw_saddle[0] + (i * 15), draw_saddle[1], ), # Stagger slightly rad=draw_saddle[2], zoom=zoom_ratio, rotation=ase_rotation, renderer=strip_renderer, arrow_props={ "arrowstyle": ArrowStyle.Fancy( head_length=arrow_head_length, head_width=arrow_head_width, tail_width=arrow_tail_width, ), "connectionstyle": f"arc3,rad={rad}", "linestyle": "-", "alpha": 0.8, "color": "black", "linewidth": 1.2, }, ) # Profile Labels final_xlabel = xlabel or ( r"RMSD ($\AA$)" if rc_mode == "rmsd" else r"Reaction Coordinate ($\AA$)" ) final_ylabel = ylabel or "Relative Energy (eV)" final_title = title # Final Aesthetics ax.set_xlabel(final_xlabel, weight="bold") ax.set_ylabel(final_ylabel, weight="bold") ax.set_title(final_title) ax.minorticks_on() if plot_type == "landscape" and not aspect_ratio: ax.set_aspect("equal") if project_path: # Force Y-axis to be symmetric and match the X-axis span, # but expand if additional structures fall outside x_min, x_max = ax.get_xlim() x_span = x_max - x_min half_span = x_span / 2 if additional_atoms_data: for _, add_r, add_p, _ in additional_atoms_data: add_d = (add_r - r_start) * v_r + (add_p - p_start) * v_p half_span = max(half_span, abs(add_d) * 1.15) ax.set_ylim(-half_span, half_span) log.info(f"Set symmetric Y-axis limits: [-{half_span:.2f}, {half_span:.2f}]") if show_legend: ax.legend( # In (s,d) space markers can appear anywhere, so let # matplotlib pick the least-overlapping corner. # In raw RMSD-RMSD the lower left is always empty. loc="best" if project_path else "lower left", borderaxespad=0.5, frameon=True, framealpha=1.0, facecolor="white", edgecolor="black", fontsize=int(active_theme.font_size * 0.8), ).set_zorder(101) plt.tight_layout(pad=0.5) if ax_strip: pos_main = ax.get_position() pos_strip = ax_strip.get_position() # Force strip to match the main plot's Left and Width exactly, # and push it down slightly to avoid overlapping the xlabel strip_y = pos_strip.y0 - 0.02 ax_strip.set_position([pos_main.x0, strip_y, pos_main.width, pos_strip.height]) if output_file: plt.savefig(output_file, transparent=False, bbox_inches="tight", dpi=dpi) else: plt.show()
if __name__ == "__main__": main()