Source code for rgpycrumbs.surfaces

import importlib

_LAZY_IMPORTS = {
    # _base
    "BaseGradientSurface": "_base",
    "BaseSurface": "_base",
    "generic_negative_mll": "_base",
    "safe_cholesky_solve": "_base",
    # _kernels
    "_imq_kernel_matrix": "_kernels",
    "_matern_kernel_matrix": "_kernels",
    "_tps_kernel_matrix": "_kernels",
    # gradient (requires jax) -- per-kernel modules
    "GradientIMQ": "gradient_imq",
    "GradientMatern": "gradient_matern",
    "GradientRQ": "gradient_rq",
    "GradientSE": "gradient_se",
    "NystromGradientIMQ": "gradient_nystrom",
    # standard (requires jax)
    "FastIMQ": "standard",
    "FastMatern": "standard",
    "FastTPS": "standard",
}

# Submodules that require jax at import time
_JAX_SUBMODULES = frozenset(
    {
        "gradient",
        "gradient_imq",
        "gradient_matern",
        "gradient_nystrom",
        "gradient_rq",
        "gradient_se",
        "standard",
        "_kernels",
    }
)

# Nystrom approximation settings.
# NOTE(review): presumably the data-set size above which the Nystrom model is
# preferred, and the default number of inducing points -- confirm at call sites.
NYSTROM_THRESHOLD: int = 1_000
NYSTROM_N_INDUCING_DEFAULT: int = 300
def nystrom_paths_needed(n_inducing, images_per_step):
    """Number of optimization steps the Nystrom approximation actually samples.

    Intended to mirror the structured sampling in
    :class:`NystromGradientIMQ._fit`. The result is
    ``max(1, ceil(n_inducing / images_per_step)) + 1``. That is enough steps
    to supply ``n_inducing`` images even when ``n_inducing`` is not a multiple
    of ``images_per_step``, plus one buffer step.

    NOTE(review): an earlier docstring described ``_fit`` as using floor
    division (``n_inducing // nimags``). The ceiling used here can return one
    step more than that. Confirm against ``_fit``.

    Args:
        n_inducing: Number of inducing points requested.
        images_per_step: Number of images per optimization step (``nimags``
            in ``_fit``). Must be positive.

    Returns:
        int: Number of steps to sample, always at least 2.

    Raises:
        ValueError: If ``images_per_step`` is not positive. Zero would
            otherwise raise ZeroDivisionError, and negative values would
            yield a meaningless count.
    """
    if images_per_step <= 0:
        raise ValueError(
            f"images_per_step must be positive, got {images_per_step!r}"
        )
    # -(-a // b) is integer ceiling division (no float round-off); +1 is the buffer.
    return max(1, -(-n_inducing // images_per_step)) + 1
__all__ = [
    "NYSTROM_N_INDUCING_DEFAULT",
    "NYSTROM_THRESHOLD",
    "BaseGradientSurface",
    "BaseSurface",
    "FastIMQ",
    "FastMatern",
    "FastTPS",
    "GradientIMQ",
    "GradientMatern",
    "GradientRQ",
    "GradientSE",
    "NystromGradientIMQ",
    "_imq_kernel_matrix",
    "_matern_kernel_matrix",
    "_tps_kernel_matrix",
    "generic_negative_mll",
    "get_surface_model",
    "nystrom_paths_needed",
    "safe_cholesky_solve",
]


def __getattr__(name):
    """Lazily resolve names listed in ``_LAZY_IMPORTS`` (PEP 562 module hook).

    Python calls this only when normal lookup on the package fails. The
    defining submodule is imported on demand. Submodules in
    ``_JAX_SUBMODULES`` are preceded by ``ensure_import("jax")``.

    Args:
        name: Attribute name being looked up on this package.

    Returns:
        The object ``name`` from its defining submodule. It is also stored in
        the package namespace, so later lookups bypass this hook.

    Raises:
        AttributeError: If ``name`` is not a lazily exported name.
        Whatever ``ensure_import("jax")`` raises when jax is unavailable
        (presumably an ImportError -- see ``rgpycrumbs._aux``).
    """
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    target = _LAZY_IMPORTS[name]
    if target in _JAX_SUBMODULES:
        from rgpycrumbs._aux import ensure_import

        ensure_import("jax")
    submod = importlib.import_module(f"rgpycrumbs.surfaces.{target}")
    value = getattr(submod, name)
    # Cache in the module namespace so later accesses are plain global lookups
    # that skip the import machinery and the jax check.
    globals()[name] = value
    return value


def __dir__():
    """List eagerly defined names plus every lazily exported name."""
    return sorted(set(globals()) | set(_LAZY_IMPORTS))
def get_surface_model(name):
    """Look up a surface model class by its short identifier.

    .. versionadded:: 1.0.0

    Args:
        name: Short model identifier, e.g. ``'grad_matern'``, ``'tps'`` or
            ``'imq'``.

    Returns:
        type: The matching model class. Any identifier not in the alias
        table resolves to ``GradientMatern``.
    """
    aliases = {
        # gradient models
        "grad_matern": "GradientMatern",
        "grad_rq": "GradientRQ",
        "grad_se": "GradientSE",
        "grad_imq": "GradientIMQ",
        "grad_imq_ny": "NystromGradientIMQ",
        # standard models
        "matern": "FastMatern",
        "imq": "FastIMQ",
        "tps": "FastTPS",
        "rbf": "FastTPS",  # second alias for the same class as "tps"
    }
    try:
        resolved = aliases[name]
    except KeyError:
        # Unknown identifiers fall back to the default model.
        resolved = "GradientMatern"
    return __getattr__(resolved)