#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2023-present Rohit Goswami <rog32@hi.is>
#
# SPDX-License-Identifier: MIT
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "click",
# "h5py",
# "pandas",
# "plotnine",
# "chemparseplot",
# ]
# ///
"""ChemGP CLI - Plot generation from HDF5 data.
Thin CLI wrapper around chemgp.plotting functions.
All plotting logic delegated to pure functions.
.. versionadded:: 1.7.0
Refactored from chemgp.plt_gp to thin CLI wrapper.
"""
import logging
from pathlib import Path
import click
import h5py
import numpy as np
from chemparseplot.parse.chemgp_hdf5 import (
read_h5_grid,
read_h5_metadata,
read_h5_path,
read_h5_points,
read_h5_table,
)
from chemparseplot.plot.chemgp import (
detect_clamp,
plot_convergence_curve,
plot_energy_profile,
plot_fps_projection,
plot_gp_progression,
plot_hyperparameter_sensitivity,
plot_nll_landscape,
plot_rff_quality,
plot_surface_contour,
plot_trust_region,
plot_variance_overlay,
save_plot,
)
# --- Logging ---
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s - %(message)s",
)
[docs]
log = logging.getLogger(__name__)
# --- Common click options ---
def common_options(func):
    """Apply the shared I/O and figure options to a click subcommand.

    Adds ``--input/-i`` (existing HDF5 file), ``--output/-o`` (PDF path),
    ``--width/-W``, ``--height/-H`` (figure size in inches) and ``--dpi``
    (output resolution) to *func*.

    :param func: the click command callback to decorate.
    :returns: *func* wrapped with the five shared options.
    """
    func = click.option(
        "--input",
        "-i",
        "input_path",
        required=True,
        type=click.Path(exists=True, path_type=Path),
        help="HDF5 data file.",
    )(func)
    func = click.option(
        "--output",
        "-o",
        "output_path",
        required=True,
        type=click.Path(path_type=Path),
        help="Output PDF path.",
    )(func)
    func = click.option(
        "--width",
        "-W",
        default=7.0,
        type=float,
        help="Figure width in inches.",
    )(func)
    func = click.option(
        "--height",
        "-H",
        default=5.0,
        type=float,
        help="Figure height in inches.",
    )(func)
    func = click.option(
        "--dpi",
        default=300,
        type=int,
        help="Output resolution.",
    )(func)
    return func
# --- CLI ---
@click.group()
def cli():
    """ChemGP figure generation from HDF5 data."""
@cli.command()
@common_options
def convergence(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Force/energy convergence vs oracle calls."""
    with h5py.File(input_path, "r") as f:
        df = read_h5_table(f, "table")
        meta = read_h5_metadata(f)
    conv_tol = meta.get("conv_tol", None)
    # Auto-detect the y column: prefer the first force metric present,
    # falling back to "force_norm".
    y = "force_norm"
    for candidate in ["ci_force", "max_fatom", "max_force"]:
        if candidate in df.columns:
            y = candidate
            break
    fig = plot_convergence_curve(
        df,
        y=y,
        conv_tol=float(conv_tol) if conv_tol is not None else None,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
@click.option("--clamp-lo", default=None, type=float)
@click.option("--clamp-hi", default=None, type=float)
@click.option("--contour-step", default=None, type=float)
def surface(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
    clamp_lo: float | None,
    clamp_hi: float | None,
    contour_step: float | None,
):
    """2D PES contour plot."""
    # Auto-detect clamping from the filename when neither clamp is given.
    # A user-supplied --contour-step is preserved (previously it was
    # silently overwritten by the auto-detected value).
    if clamp_lo is None and clamp_hi is None:
        auto_lo, auto_hi, auto_step = detect_clamp(input_path.name)
        clamp_lo, clamp_hi = auto_lo, auto_hi
        if contour_step is None:
            contour_step = auto_step
    with h5py.File(input_path, "r") as f:
        data, xc, yc = read_h5_grid(f, "energy")
        # Collect optional overlay paths.
        paths = None
        if "paths" in f:
            paths = {}
            for pname in f["paths"].keys():
                pdata = read_h5_path(f, pname)
                # assumes first key is x and second is y — TODO confirm
                # against the chemgp_hdf5 writer's key ordering
                keys = list(pdata.keys())
                paths[pname] = (pdata[keys[0]], pdata[keys[1]])
        # Collect optional overlay points.
        points = None
        if "points" in f:
            points = {}
            for pname in f["points"].keys():
                pdata = read_h5_points(f, pname)
                keys = list(pdata.keys())
                points[pname] = (pdata[keys[0]], pdata[keys[1]])
    # Build the meshgrid from stored coordinates, or fall back to indices.
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = data.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    levels = None
    if clamp_lo is not None and clamp_hi is not None:
        levels = np.linspace(clamp_lo, clamp_hi, 25)
    fig = plot_surface_contour(
        gx,
        gy,
        data,
        paths=paths,
        points=points,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        levels=levels,
        contour_step=contour_step,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
@click.option("--n-points", multiple=True, type=int, default=None)
def quality(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
    n_points: tuple[int, ...] | None,
):
    """GP surrogate quality progression (multi-panel)."""
    # Auto-detect clamping from the filename; fall back to fixed defaults.
    clamp_lo, clamp_hi, _ = detect_clamp(input_path.name)
    if clamp_lo is None:
        clamp_lo = -200.0
    if clamp_hi is None:
        clamp_hi = 50.0
    with h5py.File(input_path, "r") as f:
        true_e, xc, yc = read_h5_grid(f, "true_energy")
        # Auto-detect N values from "gp_mean_N<n>" grid names if none given.
        if not n_points:
            grid_names = [k for k in f["grids"].keys() if k.startswith("gp_mean_N")]
            n_points = sorted(int(k.replace("gp_mean_N", "")) for k in grid_names)
        grids = {}
        for n in n_points:
            gp_e, _, _ = read_h5_grid(f, f"gp_mean_N{n}")
            entry = {"gp_mean": gp_e}
            # Attach training points when the matching "train_N<n>" exists.
            pts_name = f"train_N{n}"
            if "points" in f and pts_name in f["points"]:
                pts = read_h5_points(f, pts_name)
                # assumes first key is x and second is y — TODO confirm
                keys = list(pts.keys())
                entry["train_x"] = pts[keys[0]]
                entry["train_y"] = pts[keys[1]]
            grids[n] = entry
    fig = plot_gp_progression(
        grids,
        true_e,
        xc,
        yc,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def rff(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """RFF approximation quality vs exact GP."""
    with h5py.File(input_path, "r") as f:
        df = read_h5_table(f, "table")
        meta = read_h5_metadata(f)
    # Normalize legacy column names to the ones plot_rff_quality expects.
    rename_map = {}
    if "energy_mae_vs_gp" in df.columns:
        rename_map["energy_mae_vs_gp"] = "energy_mae"
    if "gradient_mae_vs_gp" in df.columns:
        rename_map["gradient_mae_vs_gp"] = "gradient_mae"
    if "D_rff" in df.columns:
        rename_map["D_rff"] = "d_rff"
    if rename_map:
        df = df.rename(columns=rename_map)
    # Exact-GP reference errors; 0.0 when absent from metadata.
    exact_e = float(meta.get("gp_e_mae", 0.0))
    exact_g = float(meta.get("gp_g_mae", 0.0))
    fig = plot_rff_quality(
        df,
        exact_e_mae=exact_e,
        exact_g_mae=exact_g,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def nll(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """MAP-NLL landscape in hyperparameter space."""
    with h5py.File(input_path, "r") as f:
        nll_data, xc, yc = read_h5_grid(f, "nll")
        opt = read_h5_points(f, "optimum")
        # Optional gradient-norm grid overlay.
        grad_norm = None
        if "grids" in f and "grad_norm" in f["grids"]:
            grad_norm, _, _ = read_h5_grid(f, "grad_norm")
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = nll_data.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    # Mark the optimum only when both hyperparameter coordinates exist.
    optimum = None
    if "log_sigma2" in opt and "log_theta" in opt:
        optimum = (float(opt["log_sigma2"][0]), float(opt["log_theta"][0]))
    fig = plot_nll_landscape(
        gx,
        gy,
        nll_data,
        grid_grad_norm=grad_norm,
        optimum=optimum,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def sensitivity(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Hyperparameter sensitivity grid (3x3)."""
    with h5py.File(input_path, "r") as f:
        slice_df = read_h5_table(f, "slice")
        true_df = read_h5_table(f, "true_surface")
        x_vals = slice_df["x"].to_numpy()
        y_true = true_df["E_true"].to_numpy()
        # Collect the 3x3 panel tables "gp_ls<j>_sv<i>" that are present.
        panels = {}
        for j in range(1, 4):
            for i in range(1, 4):
                name = f"gp_ls{j}_sv{i}"
                if name in f:
                    gp_df = read_h5_table(f, name)
                    panels[name] = {
                        "E_pred": gp_df["E_pred"].to_numpy(),
                        "E_std": gp_df["E_std"].to_numpy(),
                    }
    fig = plot_hyperparameter_sensitivity(
        x_vals,
        y_true,
        panels,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def trust(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """Trust region illustration (1D slice)."""
    with h5py.File(input_path, "r") as f:
        slice_df = read_h5_table(f, "slice")
        training = read_h5_points(f, "training")
        meta = read_h5_metadata(f)
    x_slice = slice_df["x"].to_numpy()
    e_true = slice_df["E_true"].to_numpy()
    e_pred = slice_df["E_pred"].to_numpy()
    e_std = slice_df["E_std"].to_numpy()
    in_trust = slice_df["in_trust"].to_numpy()
    # Keep only training points whose y lies near the slice plane.
    y_slice = float(meta.get("y_slice", 0.5))
    train_x = training.get("x", np.array([]))
    train_y = training.get("y", np.array([]))
    if len(train_x) > 0 and len(train_y) > 0:
        mask = np.abs(train_y - y_slice) < 0.3
        train_x = train_x[mask]
    else:
        train_x = None
    fig = plot_trust_region(
        x_slice,
        e_true,
        e_pred,
        e_std,
        in_trust,
        train_x=train_x,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def variance(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """GP variance overlaid on PES."""
    # Auto-detect clamping from the filename; fall back to fixed defaults.
    clamp_lo, clamp_hi, _ = detect_clamp(input_path.name)
    if clamp_lo is None:
        clamp_lo = -200.0
    if clamp_hi is None:
        clamp_hi = 50.0
    with h5py.File(input_path, "r") as f:
        energy, xc, yc = read_h5_grid(f, "energy")
        var_data, _, _ = read_h5_grid(f, "variance")
        training = read_h5_points(f, "training")
        minima = None
        if "points" in f and "minima" in f["points"]:
            minima = read_h5_points(f, "minima")
        saddles = None
        if "points" in f and "saddles" in f["points"]:
            saddles = read_h5_points(f, "saddles")
    if xc is not None and yc is not None:
        gx, gy = np.meshgrid(xc, yc)
    else:
        ny, nx = energy.shape
        gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    # Build the stationary-points dict keyed "min<i>" / "saddle<i>".
    # assumes the first two keys of each points group are x and y — TODO
    # confirm against the chemgp_hdf5 writer's key ordering
    stationary = {}
    if minima is not None:
        keys = list(minima.keys())
        for idx in range(len(minima[keys[0]])):
            stationary[f"min{idx}"] = (
                float(minima[keys[0]][idx]),
                float(minima[keys[1]][idx]),
            )
    if saddles is not None:
        keys = list(saddles.keys())
        for idx in range(len(saddles[keys[0]])):
            stationary[f"saddle{idx}"] = (
                float(saddles[keys[0]][idx]),
                float(saddles[keys[1]][idx]),
            )
    train_pts = None
    if training:
        keys = list(training.keys())
        train_pts = (training[keys[0]], training[keys[1]])
    fig = plot_variance_overlay(
        gx,
        gy,
        energy,
        var_data,
        train_points=train_pts,
        stationary=stationary if stationary else None,
        clamp_lo=clamp_lo,
        clamp_hi=clamp_hi,
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def fps(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """FPS subset visualization (PCA scatter)."""
    with h5py.File(input_path, "r") as f:
        selected = read_h5_points(f, "selected")
        pruned = read_h5_points(f, "pruned")
    fig = plot_fps_projection(
        selected["pc1"],
        selected["pc2"],
        pruned["pc1"],
        pruned["pc2"],
        width=width,
        height=height,
    )
    save_plot(fig, output_path, dpi)
@cli.command()
@common_options
def profile(
    input_path: Path,
    output_path: Path,
    width: float,
    height: float,
    dpi: int,
):
    """NEB energy profile (image index vs delta E)."""
    with h5py.File(input_path, "r") as f:
        df = read_h5_table(f, "table")
    fig = plot_energy_profile(df, width=width, height=height)
    save_plot(fig, output_path, dpi)
@cli.command()
@click.option(
    "--config",
    "-c",
    "config_path",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="TOML config listing plots to generate.",
)
@click.option(
    "--base-dir",
    "-b",
    "base_dir",
    default=None,
    type=click.Path(path_type=Path),
    help="Base directory for relative paths in config.",
)
@click.option("--dpi", default=300, type=int, help="Output resolution.")
@click.option(
    "--parallel",
    "-j",
    default=1,
    type=int,
    help="Number of parallel jobs (default: 1).",
)
def batch(
    config_path: Path,
    base_dir: Path | None,
    dpi: int,
    parallel: int,
):
    """Generate multiple plots from a TOML config.

    Each ``[[plots]]`` entry names a plot ``type`` (one of the sibling
    subcommands), an ``input``/``output`` pair, and optional extra keys
    forwarded as CLI flags. Exits with status 1 if any plot fails.
    """
    import tomllib
    from concurrent.futures import ThreadPoolExecutor, as_completed

    with open(config_path, "rb") as fp:
        cfg = tomllib.load(fp)
    if base_dir is None:
        base_dir = config_path.parent
    defaults = cfg.get("defaults", {})
    input_dir = base_dir / defaults.get("input_dir", ".")
    output_dir = base_dir / defaults.get("output_dir", ".")
    plots = cfg.get("plots", [])
    if not plots:
        log.warning("No [[plots]] entries in %s", config_path)
        return
    # Map plot types to their click commands (invoked in-process below).
    cmds = {
        "convergence": convergence,
        "surface": surface,
        "quality": quality,
        "rff": rff,
        "nll": nll,
        "sensitivity": sensitivity,
        "trust": trust,
        "variance": variance,
        "fps": fps,
        "profile": profile,
    }

    def generate_single_plot(entry: dict) -> tuple[str, bool, str | None]:
        """Generate a single plot. Returns (output_name, success, error_msg)."""
        plot_type = entry.get("type")
        if plot_type not in cmds:
            return entry.get("output", "unknown"), False, f"Unknown type: {plot_type}"
        out = output_dir / entry["output"]
        w = entry.get("width", 7.0)
        h = entry.get("height", 5.0)
        d = entry.get("dpi", dpi)
        # Build arguments based on plot type.
        # NOTE(review): "landscape" is not registered in cmds, so this
        # branch is currently unreachable — confirm whether a landscape
        # command was meant to be added to the mapping.
        if plot_type == "landscape":
            src_dir = base_dir / entry.get("source_dir", ".")
            args = [
                "--source-dir",
                str(src_dir),
                "--output",
                str(out),
                "--width",
                str(w),
                "--height",
                str(h),
                "--dpi",
                str(d),
            ]
        else:
            inp = input_dir / entry["input"]
            if not inp.exists():
                return entry["output"], False, f"Input not found: {inp}"
            args = [
                "--input",
                str(inp),
                "--output",
                str(out),
                "--width",
                str(w),
                "--height",
                str(h),
                "--dpi",
                str(d),
            ]
        # Forward any extra config keys as CLI options (snake_case -> kebab).
        skip = {"type", "input", "output", "width", "height", "dpi", "source_dir"}
        for k, v in entry.items():
            if k in skip:
                continue
            flag = f"--{k.replace('_', '-')}"
            if isinstance(v, bool):
                # Booleans become bare flags, emitted only when true.
                if v:
                    args.append(flag)
            elif isinstance(v, list):
                # Lists repeat the flag; nested lists pass grouped values.
                for item in v:
                    if isinstance(item, list):
                        args.append(flag)
                        args.extend(str(x) for x in item)
                    else:
                        args.extend([flag, str(item)])
            else:
                args.extend([flag, str(v)])
        try:
            from click.testing import CliRunner

            runner = CliRunner()
            result = runner.invoke(cmds[plot_type], args)
            if result.exit_code == 0:
                return entry["output"], True, None
            else:
                return entry["output"], False, result.output
        except Exception as e:
            return entry["output"], False, str(e)

    # Process plots, tallying successes and failures.
    n_ok = 0
    n_fail = 0
    if parallel > 1:
        # Parallel processing with progress tracking.
        from rich.progress import Progress

        with Progress() as progress:
            task = progress.add_task("[cyan]Generating plots...", total=len(plots))
            with ThreadPoolExecutor(max_workers=parallel) as executor:
                futures = {
                    executor.submit(generate_single_plot, entry): entry for entry in plots
                }
                for future in as_completed(futures):
                    entry = futures[future]
                    try:
                        out_name, success, error = future.result()
                        # Plain-text status tags: rich markup is not rendered
                        # by the stdlib logging handler used here.
                        if success:
                            n_ok += 1
                            log.info("[OK] %s", out_name)
                        else:
                            n_fail += 1
                            log.error("[FAIL] %s: %s", out_name, error)
                    except Exception as e:
                        n_fail += 1
                        log.error(
                            "[FAIL] %s: %s", entry.get("output", "unknown"), e
                        )
                    progress.advance(task)
    else:
        # Sequential processing.
        for entry in plots:
            out_name, success, error = generate_single_plot(entry)
            if success:
                n_ok += 1
                log.info("[OK] %s", out_name)
            else:
                n_fail += 1
                log.error("[FAIL] %s: %s", out_name, error)
    log.info("Batch complete: %d ok, %d failed", n_ok, n_fail)
    if n_fail > 0:
        import sys

        sys.exit(1)
def main():
    """Console-script entry point: dispatch to the click group."""
    cli()


if __name__ == "__main__":
    main()