Source code for rgpycrumbs.surfaces.gradient_se

import jax
import jax.numpy as jnp
import jax.scipy.optimize as jopt
from jax import jit, vmap

from rgpycrumbs.surfaces._base import BaseGradientSurface, generic_negative_mll
from rgpycrumbs.surfaces._kernels import (
    k_matrix_se_grad_map,
    se_kernel_elem,
)

# ==============================================================================
# GRADIENT-ENHANCED SE HELPERS
# ==============================================================================


def negative_mll_se_grad(log_params, x, y_flat, D_plus_1):
    length_scale = jnp.exp(log_params[0])
    noise_scalar = jnp.exp(log_params[1])
    K_blocks = k_matrix_se_grad_map(x, x, length_scale)
    N = x.shape[0]
    K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
    return generic_negative_mll(K_full, y_flat, noise_scalar)

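
# Shape walk-through for the block flattening above. `k_matrix_se_grad_map`
# is assumed to return one (D+1, D+1) block of value/gradient
# cross-covariances per pair of points, i.e. an array of shape
# (N, N, D+1, D+1); transpose(0, 2, 1, 3) reorders it so that each
# observation owns D+1 consecutive rows and columns of the full Gram
# matrix:
#
#     (N, N, D+1, D+1) --transpose(0, 2, 1, 3)--> (N, D+1, N, D+1)
#                      --reshape---------------> (N*(D+1), N*(D+1))
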
@jit
def _grad_se_solve(x, y_full, noise_scalar, length_scale):
    K_blocks = k_matrix_se_grad_map(x, x, length_scale)
    N, _, D_plus_1, _ = K_blocks.shape
    K_full = K_blocks.transpose(0, 2, 1, 3).reshape(N * D_plus_1, N * D_plus_1)
    # Add observation noise plus a small jitter to the diagonal.
    diag_noise = (noise_scalar + 1e-6) * jnp.eye(N * D_plus_1)
    K_full = K_full + diag_noise
    K_inv = jnp.linalg.inv(K_full)
    alpha = jnp.linalg.solve(K_full, y_full.flatten())
    return alpha, K_inv

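
# NOTE: the explicit inverse is returned alongside `alpha` because the
# predictive-variance path (`_grad_se_var` below) reuses K^{-1} for every
# query chunk; the 1e-6 jitter keeps the noisy Gram matrix positive
# definite in floating point.
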
@jit
def _grad_se_predict(x_query, x_obs, alpha, length_scale):
    def get_query_row(xq, xo):
        kee = se_kernel_elem(xq, xo, length_scale)
        ked = jax.grad(se_kernel_elem, argnums=1)(xq, xo, length_scale)
        return jnp.concatenate([kee[None], ked])

    K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
    M, N, D_plus_1 = K_q.shape
    return K_q.reshape(M, N * D_plus_1) @ alpha

@jit
def _grad_se_var(x_query, x_obs, K_inv, length_scale):
    def get_query_row(xq, xo):
        kee = se_kernel_elem(xq, xo, length_scale)
        ked = jax.grad(se_kernel_elem, argnums=1)(xq, xo, length_scale)
        return jnp.concatenate([kee[None], ked])

    K_q = vmap(vmap(get_query_row, (None, 0)), (0, None))(x_query, x_obs)
    M, N, D_plus_1 = K_q.shape
    K_q_flat = K_q.reshape(M, N * D_plus_1)
    var = 1.0 - jnp.sum((K_q_flat @ K_inv) * K_q_flat, axis=1)
    return jnp.maximum(var, 0.0)

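
# NOTE: each query row stacks the value kernel k(x_q, x_o) followed by
# its gradient with respect to the observed point, matching the
# (value, gradient) layout of the training targets. Assuming the SE
# kernel has unit amplitude (k(x, x) = 1, as the leading `1.0` implies),
# the predictive variance is var(x_q) = 1 - k_q K^{-1} k_q^T, clamped at
# zero to absorb round-off.
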
class GradientSE(BaseGradientSurface):
    """Gradient-enhanced Squared Exponential (SE) surface implementation.

    .. versionadded:: 1.0.0
    """

    def _fit(self, smoothing, length_scale, optimize):
        # Default length scale: 40% of the mean coordinate span of the data.
        if length_scale is None:
            span = jnp.max(self.x, axis=0) - jnp.min(self.x, axis=0)
            init_ls = jnp.mean(span) * 0.4
        else:
            init_ls = length_scale
        init_noise = max(smoothing, 1e-4)

        if optimize:
            # Optimize in log space so both parameters stay positive.
            x0 = jnp.array([jnp.log(init_ls), jnp.log(init_noise)])

            def loss_fn(log_p):
                return negative_mll_se_grad(log_p, self.x, self.y_flat, self.D_plus_1)

            results = jopt.minimize(loss_fn, x0, method="BFGS", tol=1e-3)
            self.ls = float(jnp.exp(results.x[0]))
            self.noise = float(jnp.exp(results.x[1]))
            # Fall back to the initial guess if the optimizer diverged.
            if jnp.isnan(self.ls) or jnp.isnan(self.noise):
                self.ls, self.noise = init_ls, init_noise
        else:
            self.ls, self.noise = init_ls, init_noise

    def _solve(self):
        self.alpha, self.K_inv = _grad_se_solve(self.x, self.y_full, self.noise, self.ls)

    def _predict_chunk(self, chunk):
        return _grad_se_predict(chunk, self.x, self.alpha, self.ls)

    def _var_chunk(self, chunk):
        return _grad_se_var(chunk, self.x, self.K_inv, self.ls)
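
# ==============================================================================
# USAGE SKETCH
# ==============================================================================
# A minimal, illustrative driver for the module-level helpers on toy
# 2-D data. The (value, gradient) column layout of `y_full` is an
# assumption inferred from the query-row construction above; real
# callers go through the `GradientSE` surface API instead.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x_obs = jax.random.uniform(key, (5, 2))  # N = 5 points in D = 2
    f = jnp.sum(x_obs**2, axis=1)            # toy target f(x) = |x|^2
    g = 2.0 * x_obs                          # its analytic gradient
    y_full = jnp.concatenate([f[:, None], g], axis=1)  # assumed layout: (f, df/dx1, df/dx2)

    alpha, K_inv = _grad_se_solve(x_obs, y_full, 1e-4, 0.5)
    x_query = jnp.array([[0.5, 0.5]])
    print(_grad_se_predict(x_query, x_obs, alpha, 0.5))
    print(_grad_se_var(x_query, x_obs, K_inv, 0.5))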