import math

import jax.numpy as jnp
import jax.scipy.linalg as jsl
import jax.scipy.optimize as jopt
from jax import jit

from rgpycrumbs.surfaces._base import BaseSurface, generic_negative_mll
from rgpycrumbs.surfaces._kernels import (
    _imq_kernel_matrix,
    _matern_kernel_matrix,
    _tps_kernel_matrix,
)
# ==============================================================================
# TPS HELPERS
# ==============================================================================
def negative_mll_tps(log_params, x, y):
    """Negative log marginal likelihood of a TPS fit as a function of smoothing.

    Only the smoothing term is tuned: the thin-plate radial basis has no
    separate length scale to optimize.

    Args:
        log_params: Array whose first entry is log(smoothing); any further
            entries are ignored.
        x: Training inputs, passed to ``_tps_kernel_matrix``.
        y: Training targets.

    Returns:
        Whatever ``generic_negative_mll`` yields for the TPS kernel matrix.
    """
    gram = _tps_kernel_matrix(x)
    return generic_negative_mll(gram, y, jnp.exp(log_params[0]))
@jit
def _tps_solve(x, y, sm):
    """Fit TPS radial weights and polynomial coefficients.

    Solves the saddle-point system

        [K + sm*I   P] [w]   [y]
        [P^T        0] [v] = [0]

    where K is the TPS kernel matrix of ``x`` and P = [1, x] is the affine
    polynomial basis.

    Args:
        x: Training inputs (N, D).
        y: Training targets (N,).
        sm: Smoothing added to the diagonal of K.

    Returns:
        Tuple ``(w, v, lhs_inv)``: radial weights (N,), polynomial
        coefficients (D + 1,), and the inverse of the full system matrix
        (used by ``_tps_var``).
    """
    n = x.shape[0]
    K = _tps_kernel_matrix(x) + jnp.eye(n) * sm

    # Polynomial block [1, x]
    P = jnp.concatenate([jnp.ones((n, 1), dtype=jnp.float32), x], axis=1)
    m = P.shape[1]

    top = jnp.concatenate([K, P], axis=1)
    bot = jnp.concatenate([P.T, jnp.zeros((m, m), dtype=jnp.float32)], axis=1)
    lhs = jnp.concatenate([top, bot], axis=0)
    rhs = jnp.concatenate([y, jnp.zeros(m, dtype=jnp.float32)])

    # Factor the system once and reuse the LU for both the coefficients and
    # the explicit inverse, instead of letting solve() and inv() each
    # perform their own factorization of the same matrix.
    lu_piv = jsl.lu_factor(lhs)
    coeffs = jsl.lu_solve(lu_piv, rhs[:, None])[:, 0]
    lhs_inv = jsl.lu_solve(lu_piv, jnp.eye(n + m, dtype=lhs.dtype))
    return coeffs[:n], coeffs[n:], lhs_inv
@jit
def _tps_predict(x_query, x_obs, w, v):
    """Evaluate a fitted thin-plate spline at query points.

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        w: Radial-basis weights (N,).
        v: Coefficients for the polynomial basis [1, x] (D + 1,).

    Returns:
        Predicted values (M,).
    """
    diff = x_query[:, None, :] - x_obs[None, :, :]
    # The tiny offset keeps log(r) finite when a query coincides with a node.
    radius = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)
    radial = radius**2 * jnp.log(radius)
    n_query = x_query.shape[0]
    poly = jnp.concatenate(
        [jnp.ones((n_query, 1), dtype=jnp.float32), x_query], axis=1
    )
    return radial @ w + poly @ v
@jit
def _tps_var(x_query, x_obs, lhs_inv):
    """Variance-like uncertainty of the TPS at query points.

    For each query, forms the basis vector b = [k, 1, x_query], where k holds
    the TPS kernel values against ``x_obs``. It then returns
    max(-b^T lhs_inv b, 0), with ``lhs_inv`` being the inverse system matrix
    produced by ``_tps_solve``.

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        lhs_inv: Inverse saddle-point matrix (N + D + 1, N + D + 1).

    Returns:
        Non-negative values (M,).
    """
    diff = x_query[:, None, :] - x_obs[None, :, :]
    radius = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)
    basis = jnp.concatenate(
        [
            radius**2 * jnp.log(radius),
            jnp.ones((x_query.shape[0], 1), dtype=jnp.float32),
            x_query,
        ],
        axis=1,
    )
    quad = jnp.sum((basis @ lhs_inv) * basis, axis=1)
    return jnp.maximum(-quad, 0.0)
class FastTPS:
    """
    Thin Plate Spline (TPS) surface implementation.

    Includes an affine polynomial term ([1, x]) and supports optimizing the
    smoothing parameter.
    """

    def __init__(self, x_obs, y_obs, smoothing=1e-3, optimize=True, **_kwargs):
        """
        Initializes the TPS model.

        Args:
            x_obs: Training inputs (N, D).
            y_obs: Training observations (N,).
            smoothing: Initial smoothing parameter (floored at 1e-4).
            optimize: Whether to optimize the smoothing parameter by
                minimizing ``negative_mll_tps`` with BFGS.
            **_kwargs: Ignored (presumably accepted so the constructor
                signature matches the other surfaces; confirm with callers).
        """
        self.x_obs = jnp.asarray(x_obs, dtype=jnp.float32)
        self.y_obs = jnp.asarray(y_obs, dtype=jnp.float32)
        # TPS handles the mean via its polynomial term, but centering helps
        # the smoothing optimization stay stable.
        self.y_mean = jnp.mean(self.y_obs)
        y_centered = self.y_obs - self.y_mean

        init_sm = max(smoothing, 1e-4)
        self.sm = init_sm
        if optimize:

            def loss_fn(log_p):
                return negative_mll_tps(log_p, self.x_obs, y_centered)

            x0 = jnp.array([jnp.log(init_sm)])  # optimize [log_smoothing]
            results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
            fitted_sm = float(jnp.exp(results.x[0]))
            # Keep the initial value if the optimizer diverged; isfinite
            # rejects an overflow to inf as well as NaN.
            if math.isfinite(fitted_sm):
                self.sm = fitted_sm

        # Solved on the uncentered targets: the constant column of the
        # polynomial basis absorbs the mean, so __call__ adds no offset.
        # K_inv holds the inverse of the full saddle-point system.
        self.w, self.v, self.K_inv = _tps_solve(self.x_obs, self.y_obs, self.sm)

    @staticmethod
    def _chunked(x_query, chunk_size, fn):
        """Apply ``fn`` to consecutive row-chunks of ``x_query`` and concatenate.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        x_query = jnp.asarray(x_query, dtype=jnp.float32)
        n_query = x_query.shape[0]
        if n_query == 0:
            # jnp.concatenate rejects an empty list, so short-circuit here.
            return jnp.zeros((0,), dtype=jnp.float32)
        parts = [
            fn(x_query[start : start + chunk_size])
            for start in range(0, n_query, chunk_size)
        ]
        return jnp.concatenate(parts, axis=0)

    def __call__(self, x_query, chunk_size=500):
        """
        Predict values at query points using chunking.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Processing batch size (must be >= 1).

        Returns:
            jnp.ndarray: Predicted values (M,); empty when M == 0.
        """
        return self._chunked(
            x_query,
            chunk_size,
            lambda chunk: _tps_predict(chunk, self.x_obs, self.w, self.v),
        )

    def predict_var(self, x_query, chunk_size=500):
        """
        Predict posterior variance at query points.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Processing batch size (must be >= 1).

        Returns:
            jnp.ndarray: Predicted variances (M,); empty when M == 0.
        """
        return self._chunked(
            x_query,
            chunk_size,
            lambda chunk: _tps_var(chunk, self.x_obs, self.K_inv),
        )
# ==============================================================================
# MATERN 5/2
# ==============================================================================
def negative_mll_matern_std(log_params, x, y):
    """Negative log marginal likelihood for the Matérn 5/2 surface.

    Args:
        log_params: Array ``[log(length_scale), log(noise)]``; further
            entries are ignored.
        x: Training inputs, passed to ``_matern_kernel_matrix``.
        y: Training targets.

    Returns:
        Whatever ``generic_negative_mll`` yields for this kernel and noise.
    """
    log_ls, log_noise = log_params[0], log_params[1]
    gram = _matern_kernel_matrix(x, jnp.exp(log_ls))
    return generic_negative_mll(gram, y, jnp.exp(log_noise))
@jit
def _matern_solve(x, y, sm, length_scale):
    """Cholesky-based solve for the Matérn 5/2 surface.

    Args:
        x: Training inputs.
        y: Training targets (N,).
        sm: Noise term added to the kernel diagonal.
        length_scale: Matérn length scale.

    Returns:
        Tuple ``(alpha, K_inv)`` where alpha = (K + sm*I)^{-1} y and K_inv is
        the explicit inverse of K + sm*I (used for predictive variance).
    """
    K = _matern_kernel_matrix(x, length_scale)
    K = K + jnp.eye(x.shape[0]) * sm
    L = jnp.linalg.cholesky(K)
    # Use the triangular structure of L directly: jnp.linalg.solve would run
    # a fresh general LU factorization on each triangular system.
    alpha = jsl.cho_solve((L, True), y)
    L_inv = jsl.solve_triangular(L, jnp.eye(K.shape[0], dtype=K.dtype), lower=True)
    # L_inv^T L_inv keeps K_inv exactly symmetric.
    K_inv = L_inv.T @ L_inv
    return alpha, K_inv
@jit
def _matern_predict(x_query, x_obs, alpha, length_scale):
    """Matérn 5/2 prediction at query points: k(x_query, x_obs) @ alpha.

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        alpha: Weights (N,) from ``_matern_solve``.
        length_scale: Matérn length scale.

    Returns:
        Predicted values (M,).
    """
    sq_dist = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
    dist = jnp.sqrt(sq_dist + 1e-12)
    scaled = jnp.sqrt(5.0) * dist / length_scale
    poly = 1.0 + scaled + (5.0 * dist**2) / (3.0 * length_scale**2)
    return (poly * jnp.exp(-scaled)) @ alpha
@jit
def _matern_var(x_query, x_obs, K_inv, length_scale):
    """Matérn 5/2 predictive variance at query points, clipped at zero.

    Computes 1 - k^T K_inv k per query. The leading 1 is the kernel value at
    zero distance for this kernel form, i.e. unit prior variance.

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        K_inv: Inverse of the regularized training kernel matrix (N, N).
        length_scale: Matérn length scale.

    Returns:
        Non-negative variances (M,).
    """
    sq_dist = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
    dist = jnp.sqrt(sq_dist + 1e-12)
    scaled = jnp.sqrt(5.0) * dist / length_scale
    k_cross = (1.0 + scaled + (5.0 * dist**2) / (3.0 * length_scale**2)) * jnp.exp(-scaled)
    explained = jnp.sum((k_cross @ K_inv) * k_cross, axis=1)
    return jnp.maximum(1.0 - explained, 0.0)
class FastMatern(BaseSurface):
    """Matérn 5/2 surface implementation."""

    def _fit(self, smoothing, length_scale, optimize):
        """Choose the length scale (``self.ls``) and noise (``self.noise``).

        Args:
            smoothing: Initial noise level (floored at 1e-4).
            length_scale: Initial length scale. If None, uses 0.2 x the mean
                per-dimension span of ``x_obs``, or 1.0 when that span is
                zero or non-finite (e.g. all inputs coincide).
            optimize: If True, refine both by minimizing
                ``negative_mll_matern_std`` with BFGS. Non-finite or
                non-positive results fall back to the initial values.
        """
        if length_scale is None:
            span = jnp.max(self.x_obs, axis=0) - jnp.min(self.x_obs, axis=0)
            init_ls = float(jnp.mean(span) * 0.2)
            # A zero span would give log(0) in the optimizer and a division
            # by zero in the kernel.
            if not (math.isfinite(init_ls) and init_ls > 0.0):
                init_ls = 1.0
        else:
            init_ls = length_scale
        init_noise = max(smoothing, 1e-4)
        self.ls, self.noise = init_ls, init_noise
        if optimize:

            def loss_fn(log_p):
                return negative_mll_matern_std(log_p, self.x_obs, self.y_centered)

            x0 = jnp.array([jnp.log(init_ls), jnp.log(init_noise)])
            results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
            ls = float(jnp.exp(results.x[0]))
            noise = float(jnp.exp(results.x[1]))
            # Keep the initial guess if BFGS diverged (NaN/inf) or a
            # parameter underflowed to zero.
            if all(math.isfinite(p) and p > 0.0 for p in (ls, noise)):
                self.ls, self.noise = ls, noise

    def _solve(self):
        """Compute ``alpha`` and ``K_inv`` for the current hyperparameters."""
        self.alpha, self.K_inv = _matern_solve(
            self.x_obs, self.y_centered, self.noise, self.ls
        )

    def _predict_chunk(self, chunk):
        """Prediction for one chunk of query points via ``_matern_predict``."""
        return _matern_predict(chunk, self.x_obs, self.alpha, self.ls)

    def _var_chunk(self, chunk):
        """Variance for one chunk of query points via ``_matern_var``."""
        return _matern_var(chunk, self.x_obs, self.K_inv, self.ls)
# ==============================================================================
# STANDARD IMQ (Optimizable)
# ==============================================================================
def negative_mll_imq_std(log_params, x, y):
    """Negative log marginal likelihood for the inverse-multiquadric surface.

    Args:
        log_params: Array ``[log(epsilon), log(noise)]``; further entries
            are ignored.
        x: Training inputs, passed to ``_imq_kernel_matrix``.
        y: Training targets.

    Returns:
        Whatever ``generic_negative_mll`` yields for this kernel and noise.
    """
    log_eps, log_noise = log_params[0], log_params[1]
    gram = _imq_kernel_matrix(x, jnp.exp(log_eps))
    return generic_negative_mll(gram, y, jnp.exp(log_noise))
@jit
def _imq_solve(x, y, sm, epsilon):
    """Cholesky-based solve for the inverse-multiquadric surface.

    Args:
        x: Training inputs.
        y: Training targets (N,).
        sm: Noise term added to the kernel diagonal.
        epsilon: IMQ shape parameter.

    Returns:
        Tuple ``(alpha, K_inv)`` where alpha = (K + sm*I)^{-1} y and K_inv is
        the explicit inverse of K + sm*I (used for predictive variance).
    """
    K = _imq_kernel_matrix(x, epsilon)
    K = K + jnp.eye(x.shape[0]) * sm
    L = jnp.linalg.cholesky(K)
    # Use the triangular structure of L directly: jnp.linalg.solve would run
    # a fresh general LU factorization on each triangular system.
    alpha = jsl.cho_solve((L, True), y)
    L_inv = jsl.solve_triangular(L, jnp.eye(K.shape[0], dtype=K.dtype), lower=True)
    # L_inv^T L_inv keeps K_inv exactly symmetric.
    K_inv = L_inv.T @ L_inv
    return alpha, K_inv
@jit
def _imq_predict(x_query, x_obs, alpha, epsilon):
    """IMQ prediction at query points: k(x_query, x_obs) @ alpha.

    The kernel is k(r) = 1 / sqrt(r^2 + epsilon^2).

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        alpha: Weights (N,) from ``_imq_solve``.
        epsilon: IMQ shape parameter.

    Returns:
        Predicted values (M,).
    """
    sq_dist = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
    k_cross = 1.0 / jnp.sqrt(sq_dist + epsilon**2)
    return k_cross @ alpha
@jit
def _imq_var(x_query, x_obs, K_inv, epsilon):
    """IMQ predictive variance at query points, clipped at zero.

    Computes k(0) - k^T K_inv k per query, with k(0) = 1 / epsilon (the IMQ
    kernel at zero distance, assuming epsilon > 0).

    Args:
        x_query: Query inputs (M, D).
        x_obs: Training inputs (N, D).
        K_inv: Inverse of the regularized training kernel matrix (N, N).
        epsilon: IMQ shape parameter.

    Returns:
        Non-negative variances (M,).
    """
    sq_dist = jnp.sum((x_query[:, None, :] - x_obs[None, :, :]) ** 2, axis=-1)
    k_cross = 1.0 / jnp.sqrt(sq_dist + epsilon**2)
    explained = jnp.sum((k_cross @ K_inv) * k_cross, axis=1)
    prior = 1.0 / epsilon
    return jnp.maximum(prior - explained, 0.0)
class FastIMQ(BaseSurface):
    """Inverse multiquadric (IMQ) surface implementation."""

    def _fit(self, smoothing, length_scale, optimize):
        """Choose the shape parameter (``self.epsilon``) and noise (``self.noise``).

        Args:
            smoothing: Initial noise level (floored at 1e-4).
            length_scale: Initial epsilon. If None, uses 0.8 x the mean
                per-dimension span of ``x_obs``, or 1.0 when that span is
                zero or non-finite (e.g. all inputs coincide).
            optimize: If True, refine both by minimizing
                ``negative_mll_imq_std`` with BFGS. Non-finite or
                non-positive results fall back to the initial values.
        """
        if length_scale is None:
            span = jnp.max(self.x_obs, axis=0) - jnp.min(self.x_obs, axis=0)
            init_eps = float(jnp.mean(span) * 0.8)
            # A zero span would give log(0) in the optimizer and a division
            # by zero in the 1/epsilon prior variance.
            if not (math.isfinite(init_eps) and init_eps > 0.0):
                init_eps = 1.0
        else:
            init_eps = length_scale
        init_noise = max(smoothing, 1e-4)
        self.epsilon, self.noise = init_eps, init_noise
        if optimize:

            def loss_fn(log_p):
                return negative_mll_imq_std(log_p, self.x_obs, self.y_centered)

            x0 = jnp.array([jnp.log(init_eps), jnp.log(init_noise)])
            results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
            epsilon = float(jnp.exp(results.x[0]))
            noise = float(jnp.exp(results.x[1]))
            # Keep the initial guess if BFGS diverged (NaN/inf) or a
            # parameter underflowed to zero.
            if all(math.isfinite(p) and p > 0.0 for p in (epsilon, noise)):
                self.epsilon, self.noise = epsilon, noise

    def _solve(self):
        """Compute ``alpha`` and ``K_inv`` for the current hyperparameters."""
        self.alpha, self.K_inv = _imq_solve(
            self.x_obs, self.y_centered, self.noise, self.epsilon
        )

    def _predict_chunk(self, chunk):
        """Prediction for one chunk of query points via ``_imq_predict``."""
        return _imq_predict(chunk, self.x_obs, self.alpha, self.epsilon)

    def _var_chunk(self, chunk):
        """Variance for one chunk of query points via ``_imq_var``."""
        return _imq_var(chunk, self.x_obs, self.K_inv, self.epsilon)