Surface Fitting with Kernel Methods¶
The rgpycrumbs.surfaces module provides high-performance, differentiable surface models using JAX. These models are designed for constructing Potential Energy Surfaces (PES) and other multidimensional functions, supporting both standard observations and gradient-enhanced data.
Available Models¶
A variety of kernel-based models are provided via a unified interface:
TPS/RBF: Thin Plate Splines, standard for smooth interpolation.
Matérn 5/2: A stationary kernel with finite differentiability.
IMQ: Inverse Multiquadric kernel, often more stable over large spans.
Gradient-Enhanced: Versions of Matérn, Squared Exponential (SE), RQ, and IMQ that incorporate energy gradients directly into the covariance structure.
Nyström Gradient IMQ: A memory-efficient low-rank approximation for large datasets.
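Gradient-enhanced kernels treat observed gradients as additional correlated outputs, enlarging the covariance matrix with derivative blocks. As a minimal, self-contained sketch (NumPy with hand-derived analytic derivatives; this illustrates the structure only and is not the library's implementation), here is the joint value/derivative covariance for a 1-D squared-exponential kernel:

```python
import numpy as np

# Illustrative sketch (not the library's code): the joint covariance over
# function values AND derivatives for a 1-D squared-exponential kernel
#   k(x, x') = exp(-(x - x')**2 / (2 * l**2)).
# With D = 1 and N points, the matrix is (D+1)N x (D+1)N = 2N x 2N.
def gradient_se_cov(x, l=1.0):
    x = np.asarray(x, dtype=float)
    d = x[:, None] - x[None, :]            # pairwise differences x_i - x_j
    k = np.exp(-d**2 / (2 * l**2))         # value-value block
    k_x = (-d / l**2) * k                  # d/dx k   (first argument)
    k_xx = (1.0 / l**2 - d**2 / l**4) * k  # d^2/(dx dx') k
    # Cov(f_i, f'_j) = d/dx' k = -k_x ; Cov(f'_i, f_j) = d/dx k = k_x.
    return np.block([[k, -k_x],
                     [k_x, k_xx]])

K = gradient_se_cov([0.0, 0.5, 1.2])
print(K.shape)              # (6, 6)
print(np.allclose(K, K.T))  # True: symmetric by construction
```

The library builds the analogous blocks with `jax.grad` and `jax.jacfwd` rather than hand-derived formulas, which is what makes arbitrary kernels gradient-enhanceable.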
Uncertainty and Variance¶
Every model in surfaces.py implements a predict_var(x_query) method. This calculates the posterior predictive variance, providing a measure of uncertainty:
Interpretation: The variance represents the model’s confidence. At training points (and within the kernel’s length scale), the variance is low. In data-sparse regions, it reverts toward the prior variance (e.g., \(1/\epsilon\) for IMQ or \(1.0\) for Matérn).
Optimization: Parameters like length scales (\(l\), \(\epsilon\)) and noise scalars are optimized via Maximum Likelihood Estimation (MLE) or Maximum A Posteriori (MAP) with physically-informed priors.
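To make the low-near-data / prior-far-from-data behavior concrete, here is a hedged, self-contained sketch of the underlying formula for a noise-free GP with a unit-variance squared-exponential kernel (this mirrors, but is not, the library's `predict_var`):

```python
import numpy as np

# Illustrative sketch (assumed formula, not rgpycrumbs code): posterior
# predictive variance of a noise-free GP,
#   var(x*) = k(x*, x*) - k(x*, X) K(X, X)^{-1} k(X, x*).
def se_kernel(a, b, l=0.5):
    return np.exp(-(a[:, None] - b[None, :])**2 / (2 * l**2))

x_train = np.array([-1.0, 0.0, 1.0])
K = se_kernel(x_train, x_train) + 1e-10 * np.eye(3)  # jitter for stability

def predict_var(x_query):
    k_star = se_kernel(x_query, x_train)             # cross-covariance
    quad = np.sum(k_star * np.linalg.solve(K, k_star.T).T, axis=1)
    return 1.0 - quad                                # prior variance is 1.0

print(predict_var(np.array([0.0])))   # ~0 at a training point
print(predict_var(np.array([10.0])))  # ~1.0 (the prior variance) far away
```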
Variance Windowing and Chunking¶
To prevent Out-Of-Memory (OOM) errors during the evaluation of large grids (e.g., 2D slice visualizations), both the prediction and variance calls are internally chunked.
from rgpycrumbs.surfaces import GradientMatern

model = GradientMatern(x_obs, y_obs, gradients=g_obs)
# chunk_size defaults to 500 query points; larger chunks trade memory for speed
z_preds = model(x_grid, chunk_size=1000)
z_vars = model.predict_var(x_grid, chunk_size=1000)
This “windowed” evaluation ensures that the large \(N_{\text{query}} \times N_{\text{train}}\) cross-covariance matrices do not exhaust system memory.
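The chunking strategy itself is simple to sketch. The helper below is illustrative (`predict_chunked` is not part of the library): it slices the query grid, evaluates each slice, and concatenates the results, so only one slice-sized cross-covariance exists at a time.

```python
import numpy as np

# Illustrative sketch of windowed evaluation (names here are assumptions,
# not the library's internals). Process the query grid in slices so the
# per-call cross-covariance is at most chunk_size x N_train.
def predict_chunked(predict_fn, x_query, chunk_size=500):
    out = []
    for start in range(0, len(x_query), chunk_size):
        out.append(predict_fn(x_query[start:start + chunk_size]))
    return np.concatenate(out)

# Dummy predictor standing in for model.__call__ on a 1-D grid
x_grid = np.linspace(-2.0, 2.0, 1999)
z = predict_chunked(np.sin, x_grid, chunk_size=500)
print(z.shape)  # (1999,) — same values as np.sin(x_grid), built in slices
```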
Numerical Stability¶
Gradient-enhanced models use auto-differentiation (via jax.grad and jax.jacfwd) to construct the full \((D+1)N \times (D+1)N\) covariance matrix.
Cholesky Fallback: The safe_cholesky_solve utility attempts Cholesky decomposition with increasing jitter to handle near-singular matrices.
Float Precision: By default, models use float32 for performance and compatibility with visualization backends, but jax_enable_x64 can be toggled if higher precision is required.
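The jitter-escalation idea behind the fallback can be sketched as follows (illustrative only; the actual safe_cholesky_solve implementation may differ in its jitter schedule and solver calls):

```python
import numpy as np

# Hedged sketch of a jittered-Cholesky solve: retry the factorization with
# progressively larger diagonal jitter until it succeeds.
def safe_cholesky_solve(K, y, jitters=(0.0, 1e-10, 1e-8, 1e-6, 1e-4)):
    for jitter in jitters:
        try:
            L = np.linalg.cholesky(K + jitter * np.eye(len(K)))
        except np.linalg.LinAlgError:
            continue  # not positive definite at this jitter level; escalate
        # Solve K alpha = y via the factor: L (L^T alpha) = y.
        z = np.linalg.solve(L, y)
        return np.linalg.solve(L.T, z)
    raise np.linalg.LinAlgError("Cholesky failed even at maximum jitter")

K_bad = np.array([[1.0, 1.0],
                  [1.0, 1.0]])  # singular: plain Cholesky raises
alpha = safe_cholesky_solve(K_bad, np.array([1.0, 1.0]))
print(alpha)  # ≈ [0.5, 0.5]
```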
API Reference¶
For detailed method signatures, see the :doc:`Surfaces API Reference <../../autoapi/rgpycrumbs/surfaces/index>`.