import jax
import jax.numpy as jnp
import jax.scipy.optimize as jopt
from jax import jit, vmap
# Force float32 for speed/viz
jax.config.update("jax_enable_x64", False)
# ==============================================================================
# HELPER: GENERIC LOSS FUNCTIONS
# ==============================================================================
def safe_cholesky_solve(K, y, noise_scalar, jitter_steps=3):
    """Retries Cholesky with increasing jitter if it fails.

    Inside jit/grad tracing, jnp.linalg.cholesky does not raise on a
    non-PSD matrix -- it returns NaNs -- so a try/except would never fire.
    Failure is instead detected with jnp.isnan, and the first successful
    factorization is kept.
    """
    N = K.shape[0]
    # NaN sentinels; they survive only if every jitter level fails.
    alpha, log_det = jnp.full_like(y, jnp.nan), jnp.nan
    # Try successively larger jitters: 1e-6, 1e-5, 1e-4
    for i in range(jitter_steps):
        jitter = (noise_scalar + 10.0 ** (-6 + i)) * jnp.eye(N)
        L = jnp.linalg.cholesky(K + jitter)
        cand_alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, y))
        cand_log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
        # Accept this candidate only if nothing succeeded yet and it is finite.
        take = jnp.isnan(log_det) & ~jnp.isnan(cand_log_det)
        alpha = jnp.where(take, cand_alpha, alpha)
        log_det = jnp.where(take, cand_log_det, log_det)
    # A NaN log_det propagates to the caller, which applies a heavy penalty.
    return alpha, log_det
def generic_negative_mll(K, y, noise_scalar):
    """Negative log marginal likelihood, dropping the constant N/2 * log(2*pi)."""
    alpha, log_det = safe_cholesky_solve(K, y, noise_scalar)
    data_fit = 0.5 * jnp.dot(y.flatten(), alpha.flatten())
    complexity = 0.5 * log_det
    cost = data_fit + complexity
    # Heavy penalty if Cholesky failed (NaN propagated from the solver)
    return jnp.where(jnp.isnan(cost), 1e9, cost)
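# Hedged usage sketch (illustrative only; `_demo_generic_mll` is a
# hypothetical helper, not part of the module API): score a candidate
# noise level against a toy PSD Gram matrix.
def _demo_generic_mll():
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (16, 2))
    y = jnp.sin(x[:, 0])
    # Any PSD kernel works here; a plain squared-exponential Gram matrix.
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    K = jnp.exp(-0.5 * d2)
    return generic_negative_mll(K, y, noise_scalar=1e-3)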
# ==============================================================================
# 1. TPS IMPLEMENTATION (Optimizable)
# ==============================================================================
@jit
def _tps_kernel_matrix(x):
d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
r = jnp.sqrt(d2 + 1e-12)
K = r**2 * jnp.log(r)
return K
def negative_mll_tps(log_params, x, y):
    # TPS has only a smoothing parameter to tune in this context (the
    # length scale is inherent to the radial basis). Note the TPS kernel
    # is only conditionally positive definite, so this MLL is a heuristic
    # score that relies on the jitter added in safe_cholesky_solve.
smoothing = jnp.exp(log_params[0])
K = _tps_kernel_matrix(x)
return generic_negative_mll(K, y, smoothing)
@jit
def _tps_solve(x, y, sm):
K = _tps_kernel_matrix(x)
K = K + jnp.eye(x.shape[0]) * sm
# Polynomial Matrix
N = x.shape[0]
P = jnp.concatenate([jnp.ones((N, 1), dtype=jnp.float32), x], axis=1)
M = P.shape[1]
# Solve System
zeros = jnp.zeros((M, M), dtype=jnp.float32)
top = jnp.concatenate([K, P], axis=1)
bot = jnp.concatenate([P.T, zeros], axis=1)
lhs = jnp.concatenate([top, bot], axis=0)
rhs = jnp.concatenate([y, jnp.zeros(M, dtype=jnp.float32)])
coeffs = jnp.linalg.solve(lhs, rhs)
return coeffs[:N], coeffs[N:]
@jit
def _tps_predict(x_query, x_obs, w, v):
d2 = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
r = jnp.sqrt(d2 + 1e-12)
K_q = r**2 * jnp.log(r)
P_q = jnp.concatenate(
[jnp.ones((x_query.shape[0], 1), dtype=jnp.float32), x_query], axis=1
)
return K_q @ w + P_q @ v
class FastTPS:
def __init__(self, x_obs, y_obs, smoothing=1e-3, optimize=True, **kwargs):
self.x_obs = jnp.asarray(x_obs, dtype=jnp.float32)
self.y_obs = jnp.asarray(y_obs, dtype=jnp.float32)
# TPS handles mean via polynomial, but centering helps optimization stability
self.y_mean = jnp.mean(self.y_obs)
y_centered = self.y_obs - self.y_mean
init_sm = max(smoothing, 1e-4)
if optimize:
# Optimize [log_smoothing]
x0 = jnp.array([jnp.log(init_sm)])
def loss_fn(log_p):
return negative_mll_tps(log_p, self.x_obs, y_centered)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.sm = float(jnp.exp(results.x[0]))
if jnp.isnan(self.sm):
self.sm = init_sm
else:
self.sm = init_sm
self.w, self.v = _tps_solve(self.x_obs, self.y_obs, self.sm)
def __call__(self, x_query, chunk_size=500):
"""
Batched prediction to prevent OOM errors on large grids.
"""
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_tps_predict(chunk, self.x_obs, self.w, self.v))
return jnp.concatenate(preds, axis=0)
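# Hedged usage sketch for FastTPS (illustrative only; data and shapes
# are assumptions, not part of the module):
def _demo_fast_tps():
    key = jax.random.PRNGKey(1)
    x = jax.random.uniform(key, (50, 2))
    y = jnp.sin(3.0 * x[:, 0]) * jnp.cos(3.0 * x[:, 1])
    model = FastTPS(x, y)  # BFGS tunes the smoothing (with NaN fallback)
    grid = jax.random.uniform(jax.random.PRNGKey(2), (200, 2))
    return model(grid)  # shape (200,)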
# ==============================================================================
# 2. MATERN 5/2
# ==============================================================================
@jit
def _matern_kernel_matrix(x, length_scale):
d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
r = jnp.sqrt(d2 + 1e-12)
# Matérn 5/2 Kernel
    # k(r) = (1 + sqrt(5)*r/l + 5*r^2/(3*l^2)) * exp(-sqrt(5)*r/l)
sqrt5_r_l = jnp.sqrt(5.0) * r / length_scale
K = (1.0 + sqrt5_r_l + (5.0 * r**2) / (3.0 * length_scale**2)) * jnp.exp(-sqrt5_r_l)
return K
def negative_mll_matern_std(log_params, x, y):
length_scale = jnp.exp(log_params[0])
noise_scalar = jnp.exp(log_params[1])
K = _matern_kernel_matrix(x, length_scale)
return generic_negative_mll(K, y, noise_scalar)
@jit
def _matern_solve(x, y, sm, length_scale):
K = _matern_kernel_matrix(x, length_scale)
K = K + jnp.eye(x.shape[0]) * sm
    # Solve via Cholesky (fast and stable for positive-definite kernels like Matérn)
# Note: We don't use the polynomial 'P' matrix here usually, as Matérn
# decays to mean zero. If you want it to revert to a mean value, subtract
# mean(y) before fitting and add it back after.
L = jnp.linalg.cholesky(K)
alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, y))
return alpha
@jit
def _matern_predict(x_query, x_obs, alpha, length_scale):
d2 = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
r = jnp.sqrt(d2 + 1e-12)
sqrt5_r_l = jnp.sqrt(5.0) * r / length_scale
K_q = (1.0 + sqrt5_r_l + (5.0 * r**2) / (3.0 * length_scale**2)) * jnp.exp(
-sqrt5_r_l
)
return K_q @ alpha
class FastMatern:
def __init__(
self, x_obs, y_obs, smoothing=1e-3, length_scale=None, optimize=True, **kwargs
):
self.x_obs = jnp.asarray(x_obs, dtype=jnp.float32)
self.y_obs = jnp.asarray(y_obs, dtype=jnp.float32)
# Center the data (important for stationary kernels like Matérn)
self.y_mean = jnp.mean(self.y_obs)
y_centered = self.y_obs - self.y_mean
        # Heuristic: when no length scale is supplied, start at 20% of the
        # mean coordinate span of the data.
        if length_scale is None:
            span = jnp.max(self.x_obs, axis=0) - jnp.min(self.x_obs, axis=0)
            init_ls = jnp.mean(span) * 0.2
else:
init_ls = length_scale
init_noise = max(smoothing, 1e-4)
if optimize:
# Optimize [log_ls, log_noise]
x0 = jnp.array([jnp.log(init_ls), jnp.log(init_noise)])
def loss_fn(log_p):
return negative_mll_matern_std(log_p, self.x_obs, y_centered)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.ls = float(jnp.exp(results.x[0]))
self.noise = float(jnp.exp(results.x[1]))
if jnp.isnan(self.ls) or jnp.isnan(self.noise):
self.ls = init_ls
self.noise = init_noise
else:
self.ls = init_ls
self.noise = init_noise
self.alpha = _matern_solve(self.x_obs, y_centered, self.noise, self.ls)
def __call__(self, x_query, chunk_size=500):
"""
Batched prediction for Matern.
"""
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_matern_predict(chunk, self.x_obs, self.alpha, self.ls))
return jnp.concatenate(preds, axis=0) + self.y_mean
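# Hedged usage sketch for FastMatern (illustrative only): the class fits
# a zero-mean GP to centered data and adds the mean back at prediction
# time, so raw (non-centered) y values can be passed in.
def _demo_fast_matern():
    x = jax.random.uniform(jax.random.PRNGKey(7), (40, 2))
    y = 5.0 + jnp.sin(3.0 * x[:, 0])  # non-zero mean on purpose
    model = FastMatern(x, y)  # BFGS tunes length scale + noise
    return model(x)  # approximately reproduces y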
# ==============================================================================
# 3. GRADIENT-ENHANCED MATERN
# ==============================================================================
def matern_kernel_elem(x1, x2, length_scale=1.0):
d2 = jnp.sum((x1 - x2) ** 2)
r = jnp.sqrt(d2 + 1e-12)
ls = jnp.squeeze(length_scale)
sqrt5_r_l = jnp.sqrt(5.0) * r / ls
val = (1.0 + sqrt5_r_l + (5.0 * r**2) / (3.0 * ls**2)) * jnp.exp(-sqrt5_r_l)
return val
# --- Auto-Diff the Kernel to get Gradient Covariances ---
# This creates a function that returns the (D+1)x(D+1) covariance block
# [ Cov(E, E) Cov(E, dX) Cov(E, dY) ]
# [ Cov(dX, E) Cov(dX, dX) Cov(dX, dY) ]
# [ Cov(dY, E) Cov(dY, dX) Cov(dY, dY) ]
def full_covariance_matern(x1, x2, length_scale):
k_ee = matern_kernel_elem(x1, x2, length_scale)
k_ed = jax.grad(matern_kernel_elem, argnums=1)(x1, x2, length_scale)
k_de = jax.grad(matern_kernel_elem, argnums=0)(x1, x2, length_scale)
k_dd = jax.jacfwd(jax.grad(matern_kernel_elem, argnums=1), argnums=0)(
x1, x2, length_scale
)
# Top row: [E-E, E-dx, E-dy]
row1 = jnp.concatenate([k_ee[None], k_ed])
# Bottom rows: [dx-E, dx-dx, dx-dy]
# [dy-E, dy-dx, dy-dy]
row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
return jnp.concatenate([row1[None, :], row2], axis=0)
k_matrix_matern_grad_map = vmap(
vmap(full_covariance_matern, (None, 0, None)), (0, None, None)
)
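# Hedged sanity check (illustrative only): for N points in D dimensions the
# mapped kernel yields one (D+1) x (D+1) block per point pair, laid out as in
# the comment above, i.e. an array of shape (N, N, D+1, D+1).
def _demo_grad_block_shape():
    x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))  # N=5, D=2
    blocks = k_matrix_matern_grad_map(x, x, 1.0)
    assert blocks.shape == (5, 5, 3, 3)
    return blocks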
def negative_mll_matern_grad(log_params, x, y_flat, D_plus_1):
length_scale = jnp.exp(log_params[0])
noise_scalar = jnp.exp(log_params[1])
K_blocks = k_matrix_matern_grad_map(x, x, length_scale)
N = x.shape[0]
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
return generic_negative_mll(K_full, y_flat, noise_scalar)
@jit
def _grad_matern_solve(x, y_full, noise_scalar, length_scale):
K_blocks = k_matrix_matern_grad_map(x, x, length_scale)
N, _, D_plus_1, _ = K_blocks.shape
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
diag_noise = (noise_scalar + 1e-6) * jnp.eye(N * D_plus_1)
K_full = K_full + diag_noise
alpha = jnp.linalg.solve(K_full, y_full.flatten())
return alpha
@jit
def _grad_matern_predict(x_query, x_obs, alpha, length_scale):
def get_query_row(xq, xo):
kee = matern_kernel_elem(xq, xo, length_scale)
ked = jax.grad(matern_kernel_elem, argnums=1)(xq, xo, length_scale)
return jnp.concatenate([kee[None], ked])
K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
M, N, D_plus_1 = K_q.shape
return K_q.reshape(M, N * D_plus_1) @ alpha
class GradientMatern:
def __init__(
self,
x,
y,
gradients=None,
smoothing=1e-4,
length_scale=None,
optimize=True,
**kwargs,
):
self.x = jnp.asarray(x, dtype=jnp.float32)
y_energies = jnp.asarray(y, dtype=jnp.float32)[:, None]
grad_vals = (
jnp.asarray(gradients, dtype=jnp.float32)
if gradients is not None
else jnp.zeros_like(self.x)
)
self.y_full = jnp.concatenate([y_energies, grad_vals], axis=1)
self.e_mean = jnp.mean(y_energies)
self.y_full = self.y_full.at[:, 0].add(-self.e_mean)
self.y_flat = self.y_full.flatten()
D_plus_1 = self.x.shape[1] + 1
if length_scale is None:
span = jnp.max(self.x, axis=0) - jnp.min(self.x, axis=0)
init_ls = jnp.mean(span) * 0.5
else:
init_ls = length_scale
init_noise = max(smoothing, 1e-4)
if optimize:
x0 = jnp.array([jnp.log(init_ls), jnp.log(init_noise)])
def loss_fn(log_p):
return negative_mll_matern_grad(log_p, self.x, self.y_flat, D_plus_1)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.ls = float(jnp.exp(results.x[0]))
self.noise = float(jnp.exp(results.x[1]))
if jnp.isnan(self.ls) or jnp.isnan(self.noise):
self.ls, self.noise = init_ls, init_noise
else:
self.ls, self.noise = init_ls, init_noise
self.alpha = _grad_matern_solve(self.x, self.y_full, self.noise, self.ls)
def __call__(self, x_query, chunk_size=500):
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_grad_matern_predict(chunk, self.x, self.alpha, self.ls))
return jnp.concatenate(preds, axis=0) + self.e_mean
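# Hedged usage sketch for GradientMatern (illustrative only): energies
# plus analytic gradients of f(x, y) = x^2 + y^2, one (D,) gradient row
# per observation.
def _demo_gradient_matern():
    x = jax.random.uniform(jax.random.PRNGKey(3), (30, 2))
    y = jnp.sum(x**2, axis=1)
    grads = 2.0 * x  # grad of x^2 + y^2 is (2x, 2y)
    model = GradientMatern(x, y, gradients=grads)
    return model(x)  # should closely track y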
# ==============================================================================
# 4. STANDARD IMQ (Optimizable)
# ==============================================================================
@jit
def _imq_kernel_matrix(x, epsilon):
d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
K = 1.0 / jnp.sqrt(d2 + epsilon**2)
return K
def negative_mll_imq_std(log_params, x, y):
epsilon = jnp.exp(log_params[0])
noise_scalar = jnp.exp(log_params[1])
K = _imq_kernel_matrix(x, epsilon)
return generic_negative_mll(K, y, noise_scalar)
@jit
def _imq_solve(x, y, sm, epsilon):
K = _imq_kernel_matrix(x, epsilon)
K = K + jnp.eye(x.shape[0]) * sm
L = jnp.linalg.cholesky(K)
alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, y))
return alpha
@jit
def _imq_predict(x_query, x_obs, alpha, epsilon):
d2 = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
K_q = 1.0 / jnp.sqrt(d2 + epsilon**2)
return K_q @ alpha
class FastIMQ:
def __init__(
self, x_obs, y_obs, smoothing=1e-3, length_scale=None, optimize=True, **kwargs
):
self.x_obs = jnp.asarray(x_obs, dtype=jnp.float32)
self.y_obs = jnp.asarray(y_obs, dtype=jnp.float32)
self.y_mean = jnp.mean(self.y_obs)
y_centered = self.y_obs - self.y_mean
if length_scale is None:
span = jnp.max(self.x_obs, axis=0) - jnp.min(self.x_obs, axis=0)
init_eps = jnp.mean(span) * 0.8
else:
init_eps = length_scale
init_noise = max(smoothing, 1e-4)
if optimize:
x0 = jnp.array([jnp.log(init_eps), jnp.log(init_noise)])
def loss_fn(log_p):
return negative_mll_imq_std(log_p, self.x_obs, y_centered)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.epsilon = float(jnp.exp(results.x[0]))
self.noise = float(jnp.exp(results.x[1]))
if jnp.isnan(self.epsilon) or jnp.isnan(self.noise):
self.epsilon, self.noise = init_eps, init_noise
else:
self.epsilon, self.noise = init_eps, init_noise
self.alpha = _imq_solve(self.x_obs, y_centered, self.noise, self.epsilon)
def __call__(self, x_query, chunk_size=500):
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_imq_predict(chunk, self.x_obs, self.alpha, self.epsilon))
return jnp.concatenate(preds, axis=0) + self.y_mean
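# Hedged usage sketch for FastIMQ (illustrative only): with
# optimize=False the supplied epsilon (passed via `length_scale`) and
# smoothing are used verbatim and no BFGS call is made.
def _demo_fast_imq():
    x = jax.random.uniform(jax.random.PRNGKey(4), (40, 2))
    y = jnp.sin(4.0 * x[:, 0])
    model = FastIMQ(x, y, smoothing=1e-3, length_scale=0.7, optimize=False)
    return model(x)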
# ==============================================================================
# 5. SQUARED EXPONENTIAL (SE) - "The Classic"
# k(r) = exp(-r^2 / (2 * l^2))
# ==============================================================================
def se_kernel_elem(x1, x2, length_scale=1.0):
d2 = jnp.sum((x1 - x2) ** 2)
# Clamp length_scale to avoid division by zero
ls = jnp.maximum(length_scale, 1e-5)
val = jnp.exp(-d2 / (2.0 * ls**2))
return val
# Auto-diff covariance block
def full_covariance_se(x1, x2, length_scale):
k_ee = se_kernel_elem(x1, x2, length_scale)
k_ed = jax.grad(se_kernel_elem, argnums=1)(x1, x2, length_scale)
k_de = jax.grad(se_kernel_elem, argnums=0)(x1, x2, length_scale)
k_dd = jax.jacfwd(jax.grad(se_kernel_elem, argnums=1), argnums=0)(
x1, x2, length_scale
)
row1 = jnp.concatenate([k_ee[None], k_ed])
row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
return jnp.concatenate([row1[None, :], row2], axis=0)
# Vectorize
k_matrix_se_grad_map = vmap(vmap(full_covariance_se, (None, 0, None)), (0, None, None))
def negative_mll_se_grad(log_params, x, y_flat, D_plus_1):
length_scale = jnp.exp(log_params[0])
noise_scalar = jnp.exp(log_params[1])
K_blocks = k_matrix_se_grad_map(x, x, length_scale)
N = x.shape[0]
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
return generic_negative_mll(K_full, y_flat, noise_scalar)
@jit
def _grad_se_solve(x, y_full, noise_scalar, length_scale):
K_blocks = k_matrix_se_grad_map(x, x, length_scale)
N, _, D_plus_1, _ = K_blocks.shape
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
diag_noise = (noise_scalar + 1e-6) * jnp.eye(N * D_plus_1)
K_full = K_full + diag_noise
alpha = jnp.linalg.solve(K_full, y_full.flatten())
return alpha
@jit
def _grad_se_predict(x_query, x_obs, alpha, length_scale):
def get_query_row(xq, xo):
kee = se_kernel_elem(xq, xo, length_scale)
ked = jax.grad(se_kernel_elem, argnums=1)(xq, xo, length_scale)
return jnp.concatenate([kee[None], ked])
K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
M, N, D_plus_1 = K_q.shape
return K_q.reshape(M, N * D_plus_1) @ alpha
class GradientSE:
def __init__(
self,
x,
y,
gradients=None,
smoothing=1e-4,
length_scale=None,
optimize=True,
**kwargs,
):
self.x = jnp.asarray(x, dtype=jnp.float32)
y_energies = jnp.asarray(y, dtype=jnp.float32)[:, None]
grad_vals = (
jnp.asarray(gradients, dtype=jnp.float32)
if gradients is not None
else jnp.zeros_like(self.x)
)
self.y_full = jnp.concatenate([y_energies, grad_vals], axis=1)
self.e_mean = jnp.mean(y_energies)
self.y_full = self.y_full.at[:, 0].add(-self.e_mean)
self.y_flat = self.y_full.flatten()
D_plus_1 = self.x.shape[1] + 1
if length_scale is None:
span = jnp.max(self.x, axis=0) - jnp.min(self.x, axis=0)
init_ls = jnp.mean(span) * 0.4
else:
init_ls = length_scale
init_noise = max(smoothing, 1e-4)
if optimize:
x0 = jnp.array([jnp.log(init_ls), jnp.log(init_noise)])
def loss_fn(log_p):
return negative_mll_se_grad(log_p, self.x, self.y_flat, D_plus_1)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.ls = float(jnp.exp(results.x[0]))
self.noise = float(jnp.exp(results.x[1]))
if jnp.isnan(self.ls) or jnp.isnan(self.noise):
self.ls, self.noise = init_ls, init_noise
else:
self.ls, self.noise = init_ls, init_noise
self.alpha = _grad_se_solve(self.x, self.y_full, self.noise, self.ls)
def __call__(self, x_query, chunk_size=500):
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_grad_se_predict(chunk, self.x, self.alpha, self.ls))
return jnp.concatenate(preds, axis=0) + self.e_mean
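# Hedged usage sketch for GradientSE (illustrative only). Omitting
# `gradients` falls back to zero gradient targets, which flattens the
# fitted surface at each observation -- usually only sensible when the
# true gradients are near zero.
def _demo_gradient_se():
    x = jax.random.uniform(jax.random.PRNGKey(5), (25, 2))
    y = jnp.cos(x[:, 0] + x[:, 1])
    g = -jnp.sin(x[:, 0] + x[:, 1])  # d/dx cos(x + y) = -sin(x + y), same for y
    model = GradientSE(x, y, gradients=jnp.stack([g, g], axis=1))
    return model(x)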
# ==============================================================================
# 6. GRADIENT-ENHANCED IMQ (Optimizable)
# ==============================================================================
def imq_kernel_elem(x1, x2, epsilon=1.0):
d2 = jnp.sum((x1 - x2) ** 2)
val = 1.0 / jnp.sqrt(d2 + epsilon**2)
return val
def full_covariance_imq(x1, x2, epsilon):
k_ee = imq_kernel_elem(x1, x2, epsilon)
k_ed = jax.grad(imq_kernel_elem, argnums=1)(x1, x2, epsilon)
k_de = jax.grad(imq_kernel_elem, argnums=0)(x1, x2, epsilon)
k_dd = jax.jacfwd(jax.grad(imq_kernel_elem, argnums=1), argnums=0)(x1, x2, epsilon)
row1 = jnp.concatenate([k_ee[None], k_ed])
row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
return jnp.concatenate([row1[None, :], row2], axis=0)
k_matrix_imq_grad_map = vmap(
vmap(full_covariance_imq, (None, 0, None)), (0, None, None)
)
def negative_mll_imq_grad(log_params, x, y_flat, D_plus_1):
epsilon = jnp.exp(log_params[0])
noise_scalar = jnp.exp(log_params[1])
K_blocks = k_matrix_imq_grad_map(x, x, epsilon)
N = x.shape[0]
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
return generic_negative_mll(K_full, y_flat, noise_scalar)
def negative_mll_imq_map(log_params, init_eps, x, y_flat, D_plus_1):
log_eps = log_params[0]
log_noise = log_params[1]
epsilon = jnp.exp(log_eps)
noise_scalar = jnp.exp(log_noise)
# Likelihood (Data Fit)
K_blocks = k_matrix_imq_grad_map(x, x, epsilon)
N = x.shape[0]
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
mll_cost = generic_negative_mll(K_full, y_flat, noise_scalar)
    # --- Gamma Prior on Epsilon ---
    # The prior peaks near 'init_eps' and penalizes large values.
    # Gamma PDF (up to constants): x^(alpha-1) * exp(-beta * x)
    # NegLogPDF: -(alpha-1)*log(x) + beta*x
    alpha_g = 2.0  # Shape=2 drives the density to 0 at epsilon=0 (physical)
    beta_g = 1.0 / (init_eps + 1e-6)  # Rate chosen so the mode sits at init_eps
    # The linear beta_g * epsilon term stops epsilon from growing unboundedly
    eps_penalty = -(alpha_g - 1.0) * log_eps + beta_g * epsilon
    # --- Log-Normal Prior on Noise ---
    # A Gaussian penalty on log-noise keeps the noise near 1e-2 in magnitude
    noise_target = jnp.log(1e-2)
    noise_penalty = (log_noise - noise_target) ** 2 / 0.5
return mll_cost + eps_penalty + noise_penalty
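# Hedged sanity check on the epsilon prior (illustrative only): the
# Gamma(shape=2, rate=1/init_eps) negative log-density is minimized at
# epsilon = init_eps and grows linearly afterwards, which is what stops
# BFGS from inflating epsilon.
def _demo_eps_prior(init_eps=1.0):
    log_eps = jnp.linspace(jnp.log(0.1), jnp.log(10.0), 50)
    eps = jnp.exp(log_eps)
    penalty = -(2.0 - 1.0) * log_eps + eps / init_eps
    return eps[jnp.argmin(penalty)]  # close to init_eps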
@jit
def _grad_imq_solve(x, y_full, noise_scalar, epsilon):
K_blocks = k_matrix_imq_grad_map(x, x, epsilon)
N, _, D_plus_1, _ = K_blocks.shape
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
diag_noise = (noise_scalar + 1e-6) * jnp.eye(N * D_plus_1)
K_full = K_full + diag_noise
alpha = jnp.linalg.solve(K_full, y_full.flatten())
return alpha
@jit
def _grad_imq_predict(x_query, x_obs, alpha, epsilon):
def get_query_row(xq, xo):
kee = imq_kernel_elem(xq, xo, epsilon)
ked = jax.grad(imq_kernel_elem, argnums=1)(xq, xo, epsilon)
return jnp.concatenate([kee[None], ked])
K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
M, N, D_plus_1 = K_q.shape
return K_q.reshape(M, N * D_plus_1) @ alpha
class GradientIMQ:
def __init__(
self,
x,
y,
gradients=None,
smoothing=1e-4,
length_scale=None,
optimize=True,
**kwargs,
):
self.x = jnp.asarray(x, dtype=jnp.float32)
y_energies = jnp.asarray(y, dtype=jnp.float32)[:, None]
grad_vals = (
jnp.asarray(gradients, dtype=jnp.float32)
if gradients is not None
else jnp.zeros_like(self.x)
)
self.y_full = jnp.concatenate([y_energies, grad_vals], axis=1)
self.e_mean = jnp.mean(y_energies)
self.y_full = self.y_full.at[:, 0].add(-self.e_mean)
self.y_flat = self.y_full.flatten()
D_plus_1 = self.x.shape[1] + 1
if length_scale is None:
span = jnp.max(self.x, axis=0) - jnp.min(self.x, axis=0)
init_eps = jnp.mean(span) * 0.8
else:
init_eps = length_scale
init_noise = max(smoothing, 1e-4)
if optimize:
x0 = jnp.array([jnp.log(init_eps), jnp.log(init_noise)])
def loss_fn(log_p):
return negative_mll_imq_map(log_p, init_eps, self.x, self.y_flat, D_plus_1)
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.epsilon = float(jnp.exp(results.x[0]))
self.noise = float(jnp.exp(results.x[1]))
if jnp.isnan(self.epsilon) or jnp.isnan(self.noise):
self.epsilon, self.noise = init_eps, init_noise
else:
self.epsilon, self.noise = init_eps, init_noise
self.alpha = _grad_imq_solve(self.x, self.y_full, self.noise, self.epsilon)
def __call__(self, x_query, chunk_size=500):
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_grad_imq_predict(chunk, self.x, self.alpha, self.epsilon))
return jnp.concatenate(preds, axis=0) + self.e_mean
# ==============================================================================
# 7. RATIONAL QUADRATIC (RQ)
# k(r) = (1 + r^2 / (2 * alpha * l^2))^(-alpha)
# ==============================================================================
def rq_kernel_base(x1, x2, length_scale, alpha):
"""Standard RQ Kernel: (1 + r^2 / (2*alpha*l^2))^-alpha"""
d2 = jnp.sum((x1 - x2) ** 2)
base = 1.0 + d2 / (2.0 * alpha * (length_scale**2) + 1e-6)
val = base ** (-alpha)
return val
def rq_kernel_elem(x1, x2, params):
"""
SYMMETRIC KERNEL TRICK: k_sym(x, x') = k(x, x') + k(swap(x), x')
Enforces f(r, p) = f(p, r) globally.
"""
# Params: [length_scale, alpha]
length_scale = params[0]
alpha = params[1]
# Standard interaction
k_direct = rq_kernel_base(x1, x2, length_scale, alpha)
# Swapped interaction (Mirror across diagonal)
# x1[::-1] swaps (r, p) -> (p, r)
k_mirror = rq_kernel_base(x1[::-1], x2, length_scale, alpha)
# Summing them enforces symmetry in the output function
return k_direct + k_mirror
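# Hedged symmetry check (illustrative only): k_sym(x1, x2) equals
# k(x1, x2) + k(swap(x1), x2), so swapping the coordinates of x1 leaves
# the value unchanged, and any predictor built from this kernel
# satisfies f(r, p) = f(p, r).
def _demo_rq_symmetry():
    params = jnp.array([1.5, 1.0])
    x1 = jnp.array([0.3, 0.9])
    x2 = jnp.array([1.1, 0.2])
    return jnp.allclose(
        rq_kernel_elem(x1, x2, params),
        rq_kernel_elem(x1[::-1], x2, params),
    )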
# Auto-diff covariance block
def full_covariance_rq(x1, x2, params):
k_ee = rq_kernel_elem(x1, x2, params)
k_ed = jax.grad(rq_kernel_elem, argnums=1)(x1, x2, params)
k_de = jax.grad(rq_kernel_elem, argnums=0)(x1, x2, params)
k_dd = jax.jacfwd(jax.grad(rq_kernel_elem, argnums=1), argnums=0)(x1, x2, params)
row1 = jnp.concatenate([k_ee[None], k_ed])
row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
return jnp.concatenate([row1[None, :], row2], axis=0)
k_matrix_rq_grad_map = vmap(vmap(full_covariance_rq, (None, 0, None)), (0, None, None))
# --- MAXIMUM A POSTERIORI LOSS ---
def negative_mll_rq_map(log_params, x, y_flat, D_plus_1):
log_ls = log_params[0]
log_alpha = log_params[1]
log_noise = log_params[2]
length_scale = jnp.exp(log_ls)
alpha = jnp.exp(log_alpha)
noise_scalar = jnp.exp(log_noise)
# 1. Likelihood (Data Fit)
params = jnp.array([length_scale, alpha])
K_blocks = k_matrix_rq_grad_map(x, x, params)
N = x.shape[0]
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
mll_cost = generic_negative_mll(K_full, y_flat, noise_scalar)
# LS Prior: Target 1.5 Å (Forces global connection).
# Variance 0.05 (Very Stiff)
ls_target = jnp.log(1.5)
ls_penalty = (log_ls - ls_target) ** 2 / 0.05
# Noise Prior: Target 1e-2 (Relaxes "Exactness").
# Allows the surface to smooth out gradient conflicts (fixing bubbles).
# Variance 1.0 (Medium) -> Allows some data-driven movement.
noise_target = jnp.log(1e-2)
noise_penalty = (log_noise - noise_target) ** 2 / 1.0
# Alpha Prior: Target 0.8 (Long tails / Global structure).
alpha_target = jnp.log(0.8)
alpha_penalty = (log_alpha - alpha_target) ** 2 / 0.5
return mll_cost + ls_penalty + noise_penalty + alpha_penalty
@jit
def _grad_rq_solve(x, y_full, noise_scalar, params):
K_blocks = k_matrix_rq_grad_map(x, x, params)
N, _, D_plus_1, _ = K_blocks.shape
K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
diag_noise = (noise_scalar + 1e-6) * jnp.eye(N * D_plus_1)
K_full = K_full + diag_noise
alpha = jnp.linalg.solve(K_full, y_full.flatten())
return alpha
@jit
def _grad_rq_predict(x_query, x_obs, alpha, params):
def get_query_row(xq, xo):
kee = rq_kernel_elem(xq, xo, params)
ked = jax.grad(rq_kernel_elem, argnums=1)(xq, xo, params)
return jnp.concatenate([kee[None], ked])
K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
M, N, D_plus_1 = K_q.shape
return K_q.reshape(M, N * D_plus_1) @ alpha
class GradientRQ:
def __init__(
self,
x,
y,
gradients=None,
smoothing=1e-4,
length_scale=None,
optimize=True,
**kwargs,
):
self.x = jnp.asarray(x, dtype=jnp.float32)
y_energies = jnp.asarray(y, dtype=jnp.float32)[:, None]
grad_vals = (
jnp.asarray(gradients, dtype=jnp.float32)
if gradients is not None
else jnp.zeros_like(self.x)
)
self.y_full = jnp.concatenate([y_energies, grad_vals], axis=1)
self.e_mean = jnp.mean(y_energies)
self.y_full = self.y_full.at[:, 0].add(-self.e_mean)
self.y_flat = self.y_full.flatten()
D_plus_1 = self.x.shape[1] + 1
# Initial Guesses (Seed the optimizer in the physical basin)
init_ls = length_scale if length_scale is not None else 1.5
init_alpha = 1.0
init_noise = 1e-2
if optimize:
x0 = jnp.array([jnp.log(init_ls), jnp.log(init_alpha), jnp.log(init_noise)])
def loss_fn(log_p):
# Use the MAP loss (with priors)
return negative_mll_rq_map(log_p, self.x, self.y_flat, D_plus_1)
# BFGS with Stiff Priors
results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
self.ls = float(jnp.exp(results.x[0]))
self.alpha_param = float(jnp.exp(results.x[1]))
self.noise = float(jnp.exp(results.x[2]))
# Fallback if optimization diverges
if jnp.isnan(self.ls) or jnp.isnan(self.noise):
self.ls, self.alpha_param, self.noise = init_ls, init_alpha, init_noise
else:
self.ls, self.alpha_param, self.noise = init_ls, init_alpha, init_noise
self.params = jnp.array([self.ls, self.alpha_param])
self.alpha = _grad_rq_solve(self.x, self.y_full, self.noise, self.params)
def __call__(self, x_query, chunk_size=500):
x_query = jnp.asarray(x_query, dtype=jnp.float32)
preds = []
for i in range(0, x_query.shape[0], chunk_size):
chunk = x_query[i : i + chunk_size]
preds.append(_grad_rq_predict(chunk, self.x, self.alpha, self.params))
return jnp.concatenate(preds, axis=0) + self.e_mean
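# Hedged usage sketch for GradientRQ (illustrative only): the symmetric
# kernel makes predictions invariant to swapping the two input
# coordinates, which the query pair below exercises.
def _demo_gradient_rq():
    x = jax.random.uniform(jax.random.PRNGKey(6), (20, 2))
    y = jnp.sum(x**2, axis=1)  # already symmetric in (r, p)
    model = GradientRQ(x, y, gradients=2.0 * x, optimize=False)
    q = jnp.array([[0.2, 0.7]])
    return model(q), model(q[:, ::-1])  # identical by construction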
# Factory for string-based instantiation
def get_surface_model(name):
models = {
"grad_matern": GradientMatern,
"grad_rq": GradientRQ,
"grad_se": GradientSE,
"grad_imq": GradientIMQ,
"matern": FastMatern,
"imq": FastIMQ,
"tps": FastTPS,
"rbf": FastTPS,
}
return models.get(name, GradientMatern)
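# Hedged usage sketch for the factory (illustrative only): unknown names
# fall back to GradientMatern rather than raising.
def _demo_factory():
    Model = get_surface_model("grad_rq")  # -> GradientRQ
    Fallback = get_surface_model("typo")  # -> GradientMatern (default)
    return Model, Fallback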