#!/usr/bin/env python3
"""Plot ChemGP figures from HDF5 data files.
.. versionadded:: 1.6.0
CLI for generating publication figures from ChemGP HDF5 outputs.
Reads grids, tables, paths, and point sets, then delegates to
chemparseplot for visualization.
HDF5 layout (mirrors Julia common_plot.jl helpers):
- ``grids/<name>``: 2D arrays with attrs x_range, y_range,
x_length, y_length
- ``table/<name>``: group of same-length 1D arrays
- ``paths/<name>``: point sequences (x, y or rAB, rBC)
- ``points/<name>``: point sets (x, y or pc1, pc2)
- Root attrs: metadata scalars
"""
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "click",
# "h5py",
# "pandas",
# "plotnine",
# "chemparseplot",
# "rgpycrumbs",
# ]
# ///
import logging
import subprocess
import sys
from pathlib import Path
import click
import h5py
import numpy as np
import pandas as pd
from chemparseplot.plot.chemgp import (
plot_convergence_curve,
plot_energy_profile,
plot_fps_projection,
plot_gp_progression,
plot_hyperparameter_sensitivity,
plot_nll_landscape,
plot_rff_quality,
plot_surface_contour,
plot_trust_region,
plot_variance_overlay,
)
# --- Logging ---
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s - %(message)s",
)
# Module-level logger shared by all subcommands.
log = logging.getLogger(__name__)
# --- HDF5 helpers ---
def h5_read_table(f: "h5py.File", name: str = "table") -> pd.DataFrame:
    """Read a group of same-length 1D datasets as a DataFrame.

    Parameters
    ----------
    f : h5py.File
        Open HDF5 file (or any mapping of ``name`` -> group of datasets).
    name : str
        Group name holding the column datasets (default ``"table"``).

    Returns
    -------
    pd.DataFrame
        One column per dataset; byte-string ('S') and object ('O') dtypes
        are decoded to Python ``str``, everything else to plain lists.
    """
    g = f[name]
    cols = {}
    for key in g.keys():
        arr = g[key][()]
        if arr.dtype.kind in {"S", "O"}:
            # HDF5 strings arrive as bytes; decode for pandas/plotnine.
            cols[key] = arr.astype(str).tolist()
        else:
            cols[key] = arr.tolist()
    return pd.DataFrame(cols)
[docs]
def h5_read_grid(
f: h5py.File, name: str
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
"""Read a 2D grid with optional axis ranges.
Returns (data, x_coords, y_coords).
"""
ds = f[f"grids/{name}"]
data = ds[()]
x_coords = None
y_coords = None
if "x_range" in ds.attrs and "x_length" in ds.attrs:
lo, hi = ds.attrs["x_range"]
n = int(ds.attrs["x_length"])
x_coords = np.linspace(lo, hi, n)
if "y_range" in ds.attrs and "y_length" in ds.attrs:
lo, hi = ds.attrs["y_range"]
n = int(ds.attrs["y_length"])
y_coords = np.linspace(lo, hi, n)
return data, x_coords, y_coords
def h5_read_path(f: "h5py.File", name: str) -> dict[str, np.ndarray]:
    """Read a path (ordered point sequence) from ``paths/<name>``.

    Returns a dict mapping each dataset name (e.g. ``x``/``y`` or
    ``rAB``/``rBC``) to its 1D array.
    """
    g = f[f"paths/{name}"]
    return {k: g[k][()] for k in g.keys()}
def h5_read_points(f: "h5py.File", name: str) -> dict[str, np.ndarray]:
    """Read a point set from ``points/<name>``.

    Returns a dict mapping each dataset name (e.g. ``x``/``y`` or
    ``pc1``/``pc2``) to its 1D array.
    """
    g = f[f"points/{name}"]
    return {k: g[k][()] for k in g.keys()}


def h5_read_metadata(f: "h5py.File") -> dict:
    """Read root attributes as a plain metadata dict.

    The module docstring documents root attrs as "metadata scalars";
    this helper is called by ``convergence``, ``rff`` and ``trust`` but
    was missing from the module.
    """
    return {k: f.attrs[k] for k in f.attrs}
def save_plot(fig, output: Path, dpi: int) -> None:
    """Save a plotnine ggplot or matplotlib Figure to PDF.

    Creates parent directories as needed. Matplotlib figures are closed
    after saving to release their memory; plotnine objects manage their
    own figure lifetime via ``save``.
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415

    output.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(fig, plt.Figure):
        fig.savefig(str(output), dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        # plotnine ggplot
        fig.save(str(output), dpi=dpi, verbose=False)
    log.info("Saved: %s", output)
# --- Auto-detect clamping from filename ---
# MB surfaces need [-200, 50], LEPS need [-5, 5]
[docs]
_CLAMP_PRESETS = {
"mb": (-200.0, 50.0, 25.0), # (lo, hi, contour_step)
"leps": (-5.0, 5.0, 0.5),
}
[docs]
def _detect_clamp(filename: str) -> tuple[float | None, float | None, float | None]:
"""Detect energy clamping preset from filename."""
stem = filename.lower()
for prefix, (lo, hi, step) in _CLAMP_PRESETS.items():
if prefix in stem:
return lo, hi, step
return None, None, None
# --- Common click options ---
def common_options(func):
    """Shared options for all subcommands.

    Applies ``--input/-i``, ``--output/-o``, ``--width/-W``,
    ``--height/-H`` and ``--dpi`` to *func* and returns it, so it can be
    stacked like a decorator under ``@cli.command()``.
    """
    func = click.option(
        "--input",
        "-i",
        "input_path",
        required=True,
        type=click.Path(exists=True, path_type=Path),
        help="HDF5 data file.",
    )(func)
    func = click.option(
        "--output",
        "-o",
        "output_path",
        required=True,
        type=click.Path(path_type=Path),
        help="Output PDF path.",
    )(func)
    func = click.option(
        "--width",
        "-W",
        default=7.0,
        type=float,
        help="Figure width in inches.",
    )(func)
    func = click.option(
        "--height",
        "-H",
        default=5.0,
        type=float,
        help="Figure height in inches.",
    )(func)
    func = click.option(
        "--dpi",
        default=300,
        type=int,
        help="Output resolution.",
    )(func)
    return func
# --- CLI ---
@click.group()
def cli():
    """ChemGP figure generation from HDF5 data."""
@cli.command()
@common_options
def convergence(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Force/energy convergence vs oracle calls."""
    with h5py.File(input_path, "r") as f:
        df = h5_read_table(f, "table")
        meta = h5_read_metadata(f)
    # Detect y column: prefer ci_force > max_fatom > max_force > force_norm
    y_col = "force_norm"
    for candidate in ["ci_force", "max_fatom", "max_force"]:
        if candidate in df.columns:
            y_col = candidate
            break
    conv_tol = meta.get("conv_tol", None)
    fig = plot_convergence_curve(
        df,
        x="oracle_calls",
        y=y_col,
        color="method",
        conv_tol=float(conv_tol) if conv_tol is not None else None,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
@click.option("--clamp-lo", default=None, type=float)
@click.option("--clamp-hi", default=None, type=float)
@click.option("--contour-step", default=None, type=float)
def surface(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
    clamp_lo: float | None,
    clamp_hi: float | None,
    contour_step: float | None,
):
    """2D PES contour plot."""
    # Auto-detect clamping from filename if not specified
    if clamp_lo is None and clamp_hi is None:
        clamp_lo, clamp_hi, contour_step = _detect_clamp(input_path.name)
    with h5py.File(input_path, "r") as f:
        data, xc, yc = h5_read_grid(f, "energy")
        # Collect paths (first two datasets supply the x/y coordinates)
        paths = None
        if "paths" in f:
            paths = {}
            for pname in f["paths"].keys():
                pdata = h5_read_path(f, pname)
                keys = list(pdata.keys())
                paths[pname] = (pdata[keys[0]], pdata[keys[1]])
        # Collect points
        points = None
        if "points" in f:
            points = {}
            for pname in f["points"].keys():
                pdata = h5_read_points(f, pname)
                keys = list(pdata.keys())
                points[pname] = (pdata[keys[0]], pdata[keys[1]])
    # Build meshgrid from coordinates; fall back to index grid when the
    # file carries no axis ranges.
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = data.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    # Build explicit levels if clamping specified
    levels = None
    if clamp_lo is not None and clamp_hi is not None:
        levels = np.linspace(clamp_lo, clamp_hi, 25)
    fig = plot_surface_contour(
        gx,
        gy,
        data,
        paths=paths,
        points=points,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        levels=levels,
        contour_step=contour_step,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
@click.option("--n-points", multiple=True, type=int, default=None)
def quality(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
    n_points: tuple[int, ...] | None,
):
    """GP surrogate quality progression (multi-panel)."""
    # Auto-detect clamping from filename; fall back to MB-style limits.
    clamp_lo, clamp_hi, _ = _detect_clamp(input_path.name)
    if clamp_lo is None:
        clamp_lo = -200.0
    if clamp_hi is None:
        clamp_hi = 50.0
    with h5py.File(input_path, "r") as f:
        true_e, xc, yc = h5_read_grid(f, "true_energy")
        # Auto-detect n values from grid names if not specified
        if not n_points:
            grid_names = [k for k in f["grids"].keys() if k.startswith("gp_mean_N")]
            n_points = sorted(int(k.replace("gp_mean_N", "")) for k in grid_names)
        grids = {}
        for n in n_points:
            gp_e, _, _ = h5_read_grid(f, f"gp_mean_N{n}")
            entry = {"gp_mean": gp_e}
            # Read training points if available
            pts_name = f"train_N{n}"
            if "points" in f and pts_name in f["points"]:
                pts = h5_read_points(f, pts_name)
                keys = list(pts.keys())
                entry["train_x"] = pts[keys[0]]
                entry["train_y"] = pts[keys[1]]
            grids[n] = entry
    fig = plot_gp_progression(
        grids,
        true_e,
        xc,
        yc,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def rff(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """RFF approximation quality vs exact GP."""
    with h5py.File(input_path, "r") as f:
        df = h5_read_table(f, "table")
        meta = h5_read_metadata(f)
    # Normalize writer-side column names to what plot_rff_quality expects.
    rename_map = {}
    if "energy_mae_vs_gp" in df.columns:
        rename_map["energy_mae_vs_gp"] = "energy_mae"
    if "gradient_mae_vs_gp" in df.columns:
        rename_map["gradient_mae_vs_gp"] = "gradient_mae"
    if "D_rff" in df.columns:
        rename_map["D_rff"] = "d_rff"
    if rename_map:
        df = df.rename(columns=rename_map)
    # Exact-GP baselines come from root metadata; default to 0 when absent.
    exact_e = float(meta.get("gp_e_mae", 0.0))
    exact_g = float(meta.get("gp_g_mae", 0.0))
    fig = plot_rff_quality(
        df,
        exact_e_mae=exact_e,
        exact_g_mae=exact_g,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def nll(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """MAP-NLL landscape in hyperparameter space."""
    with h5py.File(input_path, "r") as f:
        nll_data, xc, yc = h5_read_grid(f, "nll")
        opt = h5_read_points(f, "optimum")
        # Read gradient norm grid if available
        grad_norm = None
        if "grids" in f and "grad_norm" in f["grids"]:
            grad_norm, _, _ = h5_read_grid(f, "grad_norm")
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = nll_data.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    # Optimum marker is optional; requires both hyperparameter coordinates.
    optimum = None
    if "log_sigma2" in opt and "log_theta" in opt:
        optimum = (float(opt["log_sigma2"][0]), float(opt["log_theta"][0]))
    fig = plot_nll_landscape(
        gx,
        gy,
        nll_data,
        grid_grad_norm=grad_norm,
        optimum=optimum,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def sensitivity(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Hyperparameter sensitivity grid (3x3)."""
    with h5py.File(input_path, "r") as f:
        slice_df = h5_read_table(f, "slice")
        true_df = h5_read_table(f, "true_surface")
        x_vals = slice_df["x"].to_numpy()
        y_true = true_df["E_true"].to_numpy()
        # Collect the 3x3 panel tables gp_ls{1..3}_sv{1..3} when present.
        panels = {}
        for j in range(1, 4):
            for i in range(1, 4):
                name = f"gp_ls{j}_sv{i}"
                if name in f:
                    gp_df = h5_read_table(f, name)
                    panels[name] = {
                        "E_pred": gp_df["E_pred"].to_numpy(),
                        "E_std": gp_df["E_std"].to_numpy(),
                    }
    fig = plot_hyperparameter_sensitivity(
        x_vals,
        y_true,
        panels,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def trust(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Trust region illustration (1D slice)."""
    with h5py.File(input_path, "r") as f:
        slice_df = h5_read_table(f, "slice")
        training = h5_read_points(f, "training")
        meta = h5_read_metadata(f)
    x_slice = slice_df["x"].to_numpy()
    e_true = slice_df["E_true"].to_numpy()
    e_pred = slice_df["E_pred"].to_numpy()
    e_std = slice_df["E_std"].to_numpy()
    in_trust = slice_df["in_trust"].to_numpy()
    # Training x coordinates (filter to nearby slice)
    y_slice = float(meta.get("y_slice", 0.5))
    train_x = training.get("x", np.array([]))
    train_y = training.get("y", np.array([]))
    # Keep only training points near the slice
    if len(train_x) > 0 and len(train_y) > 0:
        mask = np.abs(train_y - y_slice) < 0.3
        train_x = train_x[mask]
    else:
        train_x = None
    fig = plot_trust_region(
        x_slice,
        e_true,
        e_pred,
        e_std,
        in_trust,
        train_x=train_x,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def variance(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """GP variance overlaid on PES."""
    # Auto-detect clamping from filename; fall back to MB-style limits.
    clamp_lo, clamp_hi, _ = _detect_clamp(input_path.name)
    if clamp_lo is None:
        clamp_lo = -200.0
    if clamp_hi is None:
        clamp_hi = 50.0
    with h5py.File(input_path, "r") as f:
        energy, xc, yc = h5_read_grid(f, "energy")
        var_data, _, _ = h5_read_grid(f, "variance")
        training = h5_read_points(f, "training")
        minima = None
        if "points" in f and "minima" in f["points"]:
            minima = h5_read_points(f, "minima")
        saddles = None
        if "points" in f and "saddles" in f["points"]:
            saddles = h5_read_points(f, "saddles")
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = energy.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    # Build stationary points dict (min0, min1, ..., saddle0, ...).
    stationary = {}
    if minima is not None:
        keys = list(minima.keys())
        for idx in range(len(minima[keys[0]])):
            stationary[f"min{idx}"] = (
                float(minima[keys[0]][idx]),
                float(minima[keys[1]][idx]),
            )
    if saddles is not None:
        keys = list(saddles.keys())
        for idx in range(len(saddles[keys[0]])):
            stationary[f"saddle{idx}"] = (
                float(saddles[keys[0]][idx]),
                float(saddles[keys[1]][idx]),
            )
    train_pts = None
    if training:
        keys = list(training.keys())
        train_pts = (training[keys[0]], training[keys[1]])
    fig = plot_variance_overlay(
        gx,
        gy,
        energy,
        var_data,
        train_points=train_pts,
        stationary=stationary if stationary else None,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def fps(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """FPS subset visualization (PCA scatter)."""
    with h5py.File(input_path, "r") as f:
        selected = h5_read_points(f, "selected")
        pruned = h5_read_points(f, "pruned")
    fig = plot_fps_projection(
        selected["pc1"],
        selected["pc2"],
        pruned["pc1"],
        pruned["pc2"],
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def profile(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """NEB energy profile (image index vs delta E)."""
    with h5py.File(input_path, "r") as f:
        df = h5_read_table(f, "table")
    fig = plot_energy_profile(
        df,
        x="image",
        y="energy",
        color="method",
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@click.option(
    "--dat-pattern",
    default=None,
    type=str,
    help="Glob pattern for .dat files (eON format).",
)
@click.option(
    "--con-pattern",
    default=None,
    type=str,
    help="Glob pattern for .con path files (eON format).",
)
@click.option(
    "--source-dir",
    "-d",
    default=".",
    type=click.Path(exists=True, path_type=Path),
    help="Directory containing .dat/.con files.",
)
@click.option(
    "--output",
    "-o",
    "output_path",
    required=True,
    type=click.Path(path_type=Path),
    help="Output PDF path.",
)
@click.option("--width", "-W", default=7.0, type=float)
@click.option("--height", "-H", default=5.0, type=float)
@click.option("--dpi", default=300, type=int)
@click.option(
    "--surface-type",
    default="grad_imq",
    type=str,
    help="Surface interpolation method for plt-neb.",
)
@click.option(
    "--landscape-mode",
    default="surface",
    type=click.Choice(["path", "surface"]),
)
@click.option("--plot-structures", default="none", type=str)
@click.option("--project-path/--no-project-path", is_flag=True, default=None)
@click.option(
    "--additional-con",
    type=(str, str),
    multiple=True,
    default=None,
    help="(path, label) pairs for extra path overlays.",
)
@click.option(
    "--augment-dat", type=str, default=None, help="Glob for extra .dat surface data."
)
@click.option(
    "--augment-con", type=str, default=None, help="Glob for extra .con surface data."
)
@click.option(
    "--zoom-ratio", type=float, default=None, help="Structure image zoom ratio."
)
@click.option(
    "--ase-rotation", type=str, default=None, help="ASE rotation string for structures."
)
def landscape(
    dat_pattern: str | None,
    con_pattern: str | None,
    source_dir: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
    surface_type: str,
    landscape_mode: str,
    plot_structures: str,
    project_path: bool,
    additional_con: tuple | None,
    augment_dat: str | None,
    augment_con: str | None,
    zoom_ratio: float | None,
    ase_rotation: str | None,
):
    """2D NEB reaction landscape via plt-neb (RMSD coordinates)."""
    # Delegate to plt-neb which has the full landscape pipeline
    plt_neb_script = Path(__file__).parent.parent / "eon" / "plt_neb.py"
    cmd = [
        sys.executable,
        str(plt_neb_script),
        "--source",
        "eon",
        "--plot-type",
        "landscape",
        "--landscape-mode",
        landscape_mode,
        "--surface-type",
        surface_type,
        "--plot-structures",
        plot_structures,
        "--figsize",
        str(width),
        str(height),
        "--dpi",
        str(dpi),
        "--output-file",
        str(output_path),
    ]
    if dat_pattern:
        cmd.extend(["--input-dat-pattern", dat_pattern])
    if con_pattern:
        cmd.extend(["--input-path-pattern", con_pattern])
        # Both flags take the same pattern; only pass them when it is set,
        # otherwise ``None`` would end up in the argv list.
        cmd.extend(["--con-file", con_pattern])
    # Tri-state flag: None means "let plt-neb decide".
    if project_path:
        cmd.append("--project-path")
    elif project_path is not None:
        cmd.append("--no-project-path")
    if additional_con:
        cmd.append("--show-legend")
        for con_path, con_label in additional_con:
            cmd.extend(["--additional-con", con_path, con_label])
    if augment_dat:
        cmd.extend(["--augment-dat", augment_dat])
    if augment_con:
        cmd.extend(["--augment-con", augment_con])
    if zoom_ratio is not None:
        cmd.extend(["--zoom-ratio", str(zoom_ratio)])
    if ase_rotation is not None:
        cmd.extend(["--ase-rotation", ase_rotation])
    log.info("Delegating to plt-neb: %s", " ".join(cmd))
    # List argv, shell=False; failure is surfaced as a ClickException.
    result = subprocess.run(  # noqa: S603
        cmd, cwd=str(source_dir), capture_output=True, text=True, check=False
    )
    if result.returncode != 0:
        log.error("plt-neb failed:\n%s\n%s", result.stdout, result.stderr)
        raise click.ClickException("plt-neb landscape generation failed")
    log.info("Saved: %s", output_path)
@cli.command()
@click.option(
    "--config",
    "-c",
    "config_path",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="TOML config listing plots to generate.",
)
@click.option(
    "--base-dir",
    "-b",
    "base_dir",
    default=None,
    type=click.Path(path_type=Path),
    help="Base directory for relative paths in config.",
)
@click.option("--dpi", default=300, type=int, help="Output resolution.")
def batch(
    config_path: Path,
    base_dir: Path | None,
    dpi: int,
):
    """Generate multiple plots from a TOML config.

    Each ``[[plots]]`` entry names a subcommand via ``type`` plus its
    options; extra keys are forwarded as CLI flags. Exits nonzero if any
    plot fails.
    """
    try:
        import tomllib  # noqa: PLC0415
    except ImportError:
        import tomli as tomllib  # type: ignore[no-redef]  # noqa: PLC0415
    with open(config_path, "rb") as fp:
        cfg = tomllib.load(fp)
    if base_dir is None:
        base_dir = config_path.parent
    defaults = cfg.get("defaults", {})
    input_dir = base_dir / defaults.get("input_dir", ".")
    output_dir = base_dir / defaults.get("output_dir", ".")
    plots = cfg.get("plots", [])
    if not plots:
        log.warning("No [[plots]] entries in %s", config_path)
        return
    # Map plot type names to their click commands.
    cmds = {
        "convergence": convergence,
        "surface": surface,
        "quality": quality,
        "rff": rff,
        "nll": nll,
        "sensitivity": sensitivity,
        "trust": trust,
        "variance": variance,
        "fps": fps,
        "profile": profile,
        "landscape": landscape,
    }
    n_ok = 0
    n_fail = 0
    n_skip = 0
    for idx, entry in enumerate(plots):
        plot_type = entry.get("type")
        if plot_type not in cmds:
            log.error("Plot %d: unknown type %r, skipping", idx, plot_type)
            n_fail += 1
            continue
        out = output_dir / entry["output"]
        w = entry.get("width", 7.0)
        h = entry.get("height", 5.0)
        d = entry.get("dpi", dpi)
        # Landscape type uses source_dir instead of input HDF5
        if plot_type == "landscape":
            src_dir = base_dir / entry.get("source_dir", ".")
            inp_name = entry.get("source_dir", ".")
            args = [
                "--source-dir",
                str(src_dir),
                "--output",
                str(out),
                "--width",
                str(w),
                "--height",
                str(h),
                "--dpi",
                str(d),
            ]
        else:
            inp = input_dir / entry["input"]
            inp_name = inp.name
            if not inp.exists():
                log.warning("Skipping %s: input %s not found", entry["output"], inp)
                n_skip += 1
                continue
            args = [
                "--input",
                str(inp),
                "--output",
                str(out),
                "--width",
                str(w),
                "--height",
                str(h),
                "--dpi",
                str(d),
            ]
        # Forward extra keys as CLI options
        skip = {"type", "input", "output", "width", "height", "dpi", "source_dir"}
        for k, v in entry.items():
            if k in skip:
                continue
            flag = f"--{k.replace('_', '-')}"
            if isinstance(v, bool):
                if v:
                    args.append(flag)
            elif isinstance(v, list):
                for item in v:
                    if isinstance(item, list):
                        # Nested list: e.g. additional_con = [["path", "label"], ...]
                        args.append(flag)
                        args.extend(str(x) for x in item)
                    else:
                        args.extend([flag, str(item)])
            else:
                args.extend([flag, str(v)])
        log.info(
            "[%d/%d] %s: %s -> %s",
            idx + 1,
            len(plots),
            plot_type,
            inp_name,
            out.name,
        )
        try:
            # Invoke the subcommand in-process via its own click context.
            ctx = click.Context(cmds[plot_type])
            cmds[plot_type].parse_args(ctx, args)
            ctx.invoke(cmds[plot_type].callback, **ctx.params)
            n_ok += 1
        except Exception:
            log.exception("Plot %d (%s) failed", idx, plot_type)
            n_fail += 1
    log.info("Batch complete: %d ok, %d skipped, %d failed", n_ok, n_skip, n_fail)
    if n_fail > 0:
        sys.exit(1)
if __name__ == "__main__":
    # The click group is named ``cli``; ``main`` was never defined.
    cli()