Source code for rgpycrumbs.surfaces._kernels

import jax
import jax.numpy as jnp
from jax import jit, vmap

# ==============================================================================
# TPS KERNELS
# ==============================================================================


@jit
def _tps_kernel_matrix(x):
    # Pairwise squared distances between all rows of x.
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    # Small jitter keeps sqrt and log finite at zero distance.
    r = jnp.sqrt(d2 + 1e-12)
    # Thin-plate spline kernel: k(r) = r^2 * log(r).
    K = r**2 * jnp.log(r)
    return K
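

# Hedged usage sketch (not part of the original module): the TPS matrix is
# square in the number of input points, and its diagonal is ~0 because
# r^2 * log(r) -> 0 as r -> 0 (the 1e-12 jitter keeps log finite there).
#
#   pts = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
#   K = _tps_kernel_matrix(pts)  # shape (3, 3)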


# ==============================================================================
# MATERN KERNELS
# ==============================================================================


@jit
def _matern_kernel_matrix(x, length_scale):
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    r = jnp.sqrt(d2 + 1e-12)
    # Matérn 5/2 kernel:
    # k(r) = (1 + sqrt(5) r / l + 5 r^2 / (3 l^2)) * exp(-sqrt(5) r / l)
    sqrt5_r_l = jnp.sqrt(5.0) * r / length_scale
    K = (1.0 + sqrt5_r_l + (5.0 * r**2) / (3.0 * length_scale**2)) * jnp.exp(
        -sqrt5_r_l
    )
    return K


def matern_kernel_elem(x1, x2, length_scale=1.0):
    d2 = jnp.sum((x1 - x2) ** 2)
    r = jnp.sqrt(d2 + 1e-12)
    ls = jnp.squeeze(length_scale)
    sqrt5_r_l = jnp.sqrt(5.0) * r / ls
    val = (1.0 + sqrt5_r_l + (5.0 * r**2) / (3.0 * ls**2)) * jnp.exp(-sqrt5_r_l)
    return val


def full_covariance_matern(x1, x2, length_scale):
    # Value-value, value-gradient, gradient-value, and gradient-gradient
    # covariances between two points, stacked into a (D + 1, D + 1) block.
    k_ee = matern_kernel_elem(x1, x2, length_scale)
    k_ed = jax.grad(matern_kernel_elem, argnums=1)(x1, x2, length_scale)
    k_de = jax.grad(matern_kernel_elem, argnums=0)(x1, x2, length_scale)
    k_dd = jax.jacfwd(jax.grad(matern_kernel_elem, argnums=1), argnums=0)(
        x1, x2, length_scale
    )
    row1 = jnp.concatenate([k_ee[None], k_ed])
    row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
    return jnp.concatenate([row1[None, :], row2], axis=0)


k_matrix_matern_grad_map = vmap(
    vmap(full_covariance_matern, (None, 0, None)), (0, None, None)
)
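

# Hedged sketch (not part of the original module): k_matrix_matern_grad_map
# applied to two point sets of shapes (N, D) and (M, D) returns an
# (N, M, D + 1, D + 1) array of per-pair blocks. A hypothetical helper to
# flatten that into a single (N * (D + 1), M * (D + 1)) covariance matrix:
def _assemble_matern_grad_cov(x1, x2, length_scale):
    blocks = k_matrix_matern_grad_map(x1, x2, length_scale)
    n, m, b, _ = blocks.shape
    # Interleave point and block axes so each (D + 1, D + 1) block lands
    # contiguously in the flat matrix.
    return blocks.transpose(0, 2, 1, 3).reshape(n * b, m * b)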


# ==============================================================================
# IMQ KERNELS
# ==============================================================================


@jit
def _imq_kernel_matrix(x, epsilon):
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    # Inverse multiquadric kernel: k(r) = 1 / sqrt(r^2 + epsilon^2).
    K = 1.0 / jnp.sqrt(d2 + epsilon**2)
    return K


def imq_kernel_elem(x1, x2, epsilon=1.0):
    d2 = jnp.sum((x1 - x2) ** 2)
    val = 1.0 / jnp.sqrt(d2 + epsilon**2)
    return val


def full_covariance_imq(x1, x2, epsilon):
    k_ee = imq_kernel_elem(x1, x2, epsilon)
    k_ed = jax.grad(imq_kernel_elem, argnums=1)(x1, x2, epsilon)
    k_de = jax.grad(imq_kernel_elem, argnums=0)(x1, x2, epsilon)
    k_dd = jax.jacfwd(jax.grad(imq_kernel_elem, argnums=1), argnums=0)(x1, x2, epsilon)
    row1 = jnp.concatenate([k_ee[None], k_ed])
    row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
    return jnp.concatenate([row1[None, :], row2], axis=0)


k_matrix_imq_grad_map = vmap(
    vmap(full_covariance_imq, (None, 0, None)), (0, None, None)
)
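

# Hedged sketch (not part of the original module): the [0, 0] entry of every
# block produced by k_matrix_imq_grad_map is the plain kernel value, so it
# should agree with _imq_kernel_matrix on the same points. A hypothetical
# consistency check:
def _imq_blocks_match_matrix(x, epsilon=1.0):
    blocks = k_matrix_imq_grad_map(x, x, epsilon)
    return jnp.allclose(blocks[:, :, 0, 0], _imq_kernel_matrix(x, epsilon))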


# ==============================================================================
# SE KERNELS
# ==============================================================================


def se_kernel_elem(x1, x2, length_scale=1.0):
    d2 = jnp.sum((x1 - x2) ** 2)
    # Clamp the length scale away from zero before squaring it.
    ls = jnp.maximum(length_scale, 1e-5)
    # Squared-exponential kernel: k(r) = exp(-r^2 / (2 l^2)).
    val = jnp.exp(-d2 / (2.0 * ls**2))
    return val


def full_covariance_se(x1, x2, length_scale):
    k_ee = se_kernel_elem(x1, x2, length_scale)
    k_ed = jax.grad(se_kernel_elem, argnums=1)(x1, x2, length_scale)
    k_de = jax.grad(se_kernel_elem, argnums=0)(x1, x2, length_scale)
    k_dd = jax.jacfwd(jax.grad(se_kernel_elem, argnums=1), argnums=0)(
        x1, x2, length_scale
    )
    row1 = jnp.concatenate([k_ee[None], k_ed])
    row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
    return jnp.concatenate([row1[None, :], row2], axis=0)


k_matrix_se_grad_map = vmap(
    vmap(full_covariance_se, (None, 0, None)), (0, None, None)
)
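

# Hedged sketch (not part of the original module): the SE kernel's gradient
# block has a simple closed form, d k / d x2 = k * (x1 - x2) / l^2, which can
# be compared against the autodiff block in full_covariance_se. The helper
# name below is hypothetical.
def _se_grad_closed_form(x1, x2, length_scale=1.0):
    ls = jnp.maximum(length_scale, 1e-5)
    return se_kernel_elem(x1, x2, length_scale) * (x1 - x2) / ls**2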


# ==============================================================================
# RQ KERNELS
# ==============================================================================


def rq_kernel_base(x1, x2, length_scale, alpha):
    """Standard RQ kernel: (1 + r^2 / (2 * alpha * l^2))^(-alpha)."""
    d2 = jnp.sum((x1 - x2) ** 2)
    # The 1e-6 term guards against a vanishing denominator.
    base = 1.0 + d2 / (2.0 * alpha * (length_scale**2) + 1e-6)
    val = base ** (-alpha)
    return val


def rq_kernel_elem(x1, x2, params):
    length_scale = params[0]
    alpha = params[1]
    # Sum of the direct kernel and a mirrored copy evaluated with the
    # coordinates of x1 reversed, making the kernel symmetric under
    # reversal of the coordinate order.
    k_direct = rq_kernel_base(x1, x2, length_scale, alpha)
    k_mirror = rq_kernel_base(x1[::-1], x2, length_scale, alpha)
    return k_direct + k_mirror


def full_covariance_rq(x1, x2, params):
    k_ee = rq_kernel_elem(x1, x2, params)
    k_ed = jax.grad(rq_kernel_elem, argnums=1)(x1, x2, params)
    k_de = jax.grad(rq_kernel_elem, argnums=0)(x1, x2, params)
    k_dd = jax.jacfwd(jax.grad(rq_kernel_elem, argnums=1), argnums=0)(x1, x2, params)
    row1 = jnp.concatenate([k_ee[None], k_ed])
    row2 = jnp.concatenate([k_de[:, None], k_dd], axis=1)
    return jnp.concatenate([row1[None, :], row2], axis=0)


k_matrix_rq_grad_map = vmap(
    vmap(full_covariance_rq, (None, 0, None)), (0, None, None)
)
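

# Hedged usage sketch (not part of the original module): unlike the other
# kernels, the RQ map packs its hyperparameters into a single array,
# params = [length_scale, alpha].
#
#   params = jnp.array([1.0, 2.0])
#   x = jnp.array([[0.0, 0.0], [0.5, 1.0]])
#   blocks = k_matrix_rq_grad_map(x, x, params)  # shape (2, 2, 3, 3)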