Source code for rgpycrumbs.surfaces._base

import logging

import jax.numpy as jnp


def safe_cholesky_solve(K, y, noise_scalar, jitter_steps=3):
    """
    Retries Cholesky decomposition with increasing jitter if it fails.

    Args:
        K: Covariance matrix.
        y: Observation vector.
        noise_scalar: Initial noise level.
        jitter_steps: Number of retry attempts with increasing jitter.

    Returns:
        tuple: (alpha, log_det) where alpha is the solution vector and
            log_det is the log determinant of the jittered matrix.
    """
    N = K.shape[0]
    # Try successively larger jitters: 1e-6, 1e-5, 1e-4
    for i in range(jitter_steps):
        jitter = (noise_scalar + 10 ** (-6 + i)) * jnp.eye(N)
        # NOTE: unlike numpy, jnp.linalg.cholesky does not raise on an
        # indefinite matrix; a failed factorization surfaces as NaNs in L,
        # which propagate to the caller (see the penalty in
        # generic_negative_mll below).
        try:
            L = jnp.linalg.cholesky(K + jitter)
            alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, y))
            log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
            return alpha, log_det
        except Exception as e:
            logging.debug(f"Cholesky failed: {e}")
            continue
    # Fallback for compilation safety (NaN propagation)
    return jnp.zeros_like(y), jnp.nan
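A minimal usage sketch, assuming the module's imports are in scope. The 5-point RBF kernel below is an assumption for illustration; any symmetric positive semi-definite covariance matrix works the same way.

pts = jnp.linspace(0.0, 1.0, 5)[:, None]                    # five 1-D inputs
sq_dists = jnp.sum((pts[:, None, :] - pts[None, :, :]) ** 2, axis=-1)
K_demo = jnp.exp(-0.5 * sq_dists / 0.3 ** 2)                # RBF, length scale 0.3
y_demo = jnp.sin(3.0 * pts[:, 0])

alpha, log_det = safe_cholesky_solve(K_demo, y_demo, noise_scalar=1e-3)
# alpha solves (K + jitter * I) @ alpha = y; log_det is log|K + jitter * I|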
def generic_negative_mll(K, y, noise_scalar):
    """
    Calculates the negative Marginal Log-Likelihood (MLL).

    Args:
        K: Covariance matrix.
        y: Observation vector.
        noise_scalar: Noise level for regularization.

    Returns:
        float: The negative MLL value, or a high penalty if Cholesky fails.
    """
    alpha, log_det = safe_cholesky_solve(K, y, noise_scalar)
    data_fit = 0.5 * jnp.dot(y.flatten(), alpha.flatten())
    complexity = 0.5 * log_det
    # The constant 0.5 * N * log(2 * pi) term is omitted; it does not
    # affect hyperparameter optimization.
    cost = data_fit + complexity
    # Heavy penalty if Cholesky failed (NaN)
    return jnp.where(jnp.isnan(cost), 1e9, cost)
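A sketch of one way this cost can drive hyperparameter selection: a plain grid search over RBF length scales. The `_rbf` helper is hypothetical, for illustration only; the concrete subclasses are free to optimize differently.

xs = jnp.linspace(0.0, 1.0, 6)[:, None]
ys = jnp.cos(4.0 * xs[:, 0])

def _rbf(x, ls):
    # Hypothetical RBF kernel builder, illustration only.
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / ls ** 2)

length_scales = jnp.array([0.1, 0.3, 1.0, 3.0])
costs = jnp.stack(
    [generic_negative_mll(_rbf(xs, ls), ys, 1e-3) for ls in length_scales]
)
best_ls = length_scales[jnp.argmin(costs)]                  # lowest cost wins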
class BaseSurface:
    """
    Abstract base class for standard (non-gradient) surface models.

    Derived classes must implement `_fit`, `_solve`, `_predict_chunk`,
    and `_var_chunk`.
    """

    def __init__(
        self, x_obs, y_obs, smoothing=1e-3, length_scale=None, optimize=True, **_kwargs
    ):
        """
        Initializes and fits the surface model.

        Args:
            x_obs: Training inputs (N, D).
            y_obs: Training observations (N,).
            smoothing: Initial noise/smoothing parameter.
            length_scale: Initial length scale parameter(s).
            optimize: Whether to optimize parameters via MLE.
            **_kwargs: Additional model-specific parameters (unused here).
        """
        self.x_obs = jnp.asarray(x_obs, dtype=jnp.float32)
        self.y_obs = jnp.asarray(y_obs, dtype=jnp.float32)
        # Center the data
        self.y_mean = jnp.mean(self.y_obs)
        self.y_centered = self.y_obs - self.y_mean
        self._fit(smoothing, length_scale, optimize)
        self._solve()
    def _fit(self, smoothing, length_scale, optimize):
        """Internal method to perform parameter optimization."""
        raise NotImplementedError

    def _solve(self):
        """Internal method to solve the linear system for weights."""
        raise NotImplementedError

    def __call__(self, x_query, chunk_size=500):
        """
        Predict values at query points.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Number of points to process per batch to avoid OOM.

        Returns:
            jnp.ndarray: Predicted values (M,).
        """
        x_query = jnp.asarray(x_query, dtype=jnp.float32)
        preds = []
        for i in range(0, x_query.shape[0], chunk_size):
            chunk = x_query[i : i + chunk_size]
            preds.append(self._predict_chunk(chunk))
        return jnp.concatenate(preds, axis=0) + self.y_mean

    def predict_var(self, x_query, chunk_size=500):
        """
        Predict posterior variance at query points.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Number of points to process per batch.

        Returns:
            jnp.ndarray: Predicted variances (M,).
        """
        x_query = jnp.asarray(x_query, dtype=jnp.float32)
        vars_list = []
        for i in range(0, x_query.shape[0], chunk_size):
            chunk = x_query[i : i + chunk_size]
            vars_list.append(self._var_chunk(chunk))
        return jnp.concatenate(vars_list, axis=0)

    def _predict_chunk(self, chunk):
        """Internal method for batch prediction."""
        raise NotImplementedError

    def _var_chunk(self, chunk):
        """Internal method for batch variance."""
        raise NotImplementedError
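To make the contract concrete, here is a hypothetical minimal subclass: an RBF surface with fixed hyperparameters. It is a sketch of the four abstract hooks, not one of the library's real models, and assumes the helpers defined above are in scope.

class _DemoRBFSurface(BaseSurface):
    """Hypothetical minimal subclass, for illustration only."""

    def _fit(self, smoothing, length_scale, optimize):
        # Fixed hyperparameters; a real subclass would minimize
        # generic_negative_mll here when optimize=True.
        self.noise = smoothing
        self.ls = 1.0 if length_scale is None else length_scale

    def _kernel(self, a, b):
        d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return jnp.exp(-0.5 * d2 / self.ls ** 2)

    def _solve(self):
        K = self._kernel(self.x_obs, self.x_obs)
        self.alpha, _ = safe_cholesky_solve(K, self.y_centered, self.noise)

    def _predict_chunk(self, chunk):
        return self._kernel(chunk, self.x_obs) @ self.alpha

    def _var_chunk(self, chunk):
        # Posterior variance with unit prior variance:
        # k(x, x) - k_* (K + noise * I)^{-1} k_*^T, clipped at zero.
        k_star = self._kernel(chunk, self.x_obs)
        K = self._kernel(self.x_obs, self.x_obs)
        v = jnp.linalg.solve(K + self.noise * jnp.eye(K.shape[0]), k_star.T)
        return jnp.clip(1.0 - jnp.sum(k_star * v.T, axis=1), 0.0)

Constructing _DemoRBFSurface(x, y) runs _fit and _solve once; surf(x_query) and surf.predict_var(x_query) then batch through the chunked loops above.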
class BaseGradientSurface:
    """
    Abstract base class for gradient-enhanced surface models.

    Derived classes must implement `_fit`, `_solve`, `_predict_chunk`,
    and `_var_chunk`. These models incorporate both values and their
    gradients into the fit.
    """

    def __init__(
        self,
        x,
        y,
        gradients=None,
        smoothing=1e-4,
        length_scale=None,
        optimize=True,
        **_kwargs,
    ):
        """
        Initializes and fits the gradient-enhanced surface model.

        Args:
            x: Training inputs (N, D).
            y: Training values (N,).
            gradients: Training gradients (N, D).
            smoothing: Initial noise/smoothing parameter.
            length_scale: Initial length scale parameter(s).
            optimize: Whether to optimize parameters.
            **_kwargs: Additional model-specific parameters (unused here).
        """
        self.x = jnp.asarray(x, dtype=jnp.float32)
        y_energies = jnp.asarray(y, dtype=jnp.float32)[:, None]
        grad_vals = (
            jnp.asarray(gradients, dtype=jnp.float32)
            if gradients is not None
            else jnp.zeros_like(self.x)
        )
        # Stack values and gradients: row i is [E_i, dE_i/dx_1, ..., dE_i/dx_D]
        self.y_full = jnp.concatenate([y_energies, grad_vals], axis=1)
        # Center the energies only; gradients are unaffected by a constant shift
        self.e_mean = jnp.mean(y_energies)
        self.y_full = self.y_full.at[:, 0].add(-self.e_mean)
        self.y_flat = self.y_full.flatten()
        self.D_plus_1 = self.x.shape[1] + 1
        self._fit(smoothing, length_scale, optimize)
        self._solve()
    def _fit(self, smoothing, length_scale, optimize):
        """Internal method to perform parameter optimization."""
        raise NotImplementedError

    def _solve(self):
        """Internal method to solve the linear system for weights."""
        raise NotImplementedError

    def __call__(self, x_query, chunk_size=500):
        """
        Predict values at query points.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Number of points to process per batch.

        Returns:
            jnp.ndarray: Predicted values (M,).
        """
        x_query = jnp.asarray(x_query, dtype=jnp.float32)
        preds = []
        for i in range(0, x_query.shape[0], chunk_size):
            chunk = x_query[i : i + chunk_size]
            preds.append(self._predict_chunk(chunk))
        return jnp.concatenate(preds, axis=0) + self.e_mean

    def predict_var(self, x_query, chunk_size=500):
        """
        Predict posterior variance at query points.

        Args:
            x_query: Query inputs (M, D).
            chunk_size: Number of points to process per batch.

        Returns:
            jnp.ndarray: Predicted variances (M,).
        """
        x_query = jnp.asarray(x_query, dtype=jnp.float32)
        vars_list = []
        for i in range(0, x_query.shape[0], chunk_size):
            chunk = x_query[i : i + chunk_size]
            vars_list.append(self._var_chunk(chunk))
        return jnp.concatenate(vars_list, axis=0)

    def _predict_chunk(self, chunk):
        """Internal method for batch prediction."""
        raise NotImplementedError

    def _var_chunk(self, chunk):
        """Internal method for batch variance."""
        raise NotImplementedError
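For reference, the flattened target layout a gradient-enhanced subclass must reproduce: with D input dimensions, each observation contributes D_plus_1 consecutive entries [E_i, dE_i/dx_1, ..., dE_i/dx_D], with energies centered by e_mean. The snippet below replicates the constructor's bookkeeping by hand, since the abstract class itself cannot be instantiated.

e_demo = jnp.array([1.0, 2.0])                     # N = 2 energies
g_demo = jnp.array([[0.1, 0.2], [0.3, 0.4]])       # N = 2 gradients, D = 2
e_mean = jnp.mean(e_demo)                          # 1.5
y_full = jnp.concatenate([(e_demo - e_mean)[:, None], g_demo], axis=1)
y_flat = y_full.flatten()
# y_flat == [-0.5, 0.1, 0.2, 0.5, 0.3, 0.4], with D_plus_1 == 3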