# Source code for rgpycrumbs._aux

import contextlib
import importlib
import logging
import os
import shutil
import subprocess
import sys
from pathlib import Path

# Module-level logger; installer activity in this module is reported here.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Dependency registry
# ---------------------------------------------------------------------------
# Maps importable module names to (pip_spec, extra_name).
# Conda-only deps (ira_mod, tblite, ovito) are intentionally absent;
# they fall through to a helpful error suggesting pixi.
_DEPENDENCY_MAP: dict[str, tuple[str, str]] = {
    # extra: surfaces
    "jax": ("jax>=0.4", "surfaces"),
    "jaxlib": ("jax>=0.4", "surfaces"),
    # extra: interpolation
    "scipy": ("scipy>=1.11", "interpolation"),
    "scipy.interpolate": ("scipy>=1.11", "interpolation"),
    # extra: analysis
    "scipy.spatial": ("scipy>=1.11", "analysis"),
    "scipy.spatial.distance": ("scipy>=1.11", "analysis"),
    "ase": ("ase>=3.22", "analysis"),
    "ase.data": ("ase>=3.22", "analysis"),
    "ase.neighborlist": ("ase>=3.22", "analysis"),
    "adjustText": ("adjustText>=1.0", "analysis"),
    "chemparseplot": ("chemparseplot", "analysis"),
    "chemparseplot.plot.chemgp": ("chemparseplot", "analysis"),
    "h5py": ("h5py", "analysis"),
    "matplotlib": ("matplotlib>=3.7", "analysis"),
    "matplotlib.pyplot": ("matplotlib>=3.7", "analysis"),
    "pandas": ("pandas>=2.0", "analysis"),
    "polars": ("polars>=1.0", "analysis"),
}
# CPU-only pip spec overrides for packages with heavy GPU backends.
# Used instead of the registry spec when no CUDA device is detected, so an
# install does not pull hundreds of megabytes of CUDA libraries.
_CPU_OVERRIDES: dict[str, str] = {
    "jax": "jax[cpu]>=0.4",
    "jaxlib": "jax[cpu]>=0.4",
}
# Cache the result of the CUDA probe so it runs at most once per process.
[docs] _cuda_available: bool | None = None
def _has_cuda() -> bool:
    """Report whether a working NVIDIA GPU appears to be present.

    ``nvidia-smi`` is looked up on PATH and executed; a clean exit (within
    five seconds) counts as CUDA being available.  The outcome is stored in
    the module-level ``_cuda_available`` so the probe runs at most once per
    process.
    """
    global _cuda_available
    if _cuda_available is None:
        smi_path = shutil.which("nvidia-smi")
        if smi_path is None:
            _cuda_available = False
        else:
            try:
                subprocess.run(  # noqa: S603
                    [smi_path],
                    check=True,
                    capture_output=True,
                    timeout=5,
                )
            except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
                # Missing driver, hung tool, or non-zero exit: treat as no GPU.
                _cuda_available = False
            else:
                _cuda_available = True
    return _cuda_available
def _get_dep_cache_dir() -> Path:
    """Return the per-user dependency cache directory.

    Defaults to ``$XDG_CACHE_HOME/rgpycrumbs/deps/`` (typically
    ``~/.cache/rgpycrumbs/deps/``).

    As the XDG Base Directory specification requires, an unset, empty, or
    *relative* ``XDG_CACHE_HOME`` is ignored and ``~/.cache`` is used
    instead.  Honouring a relative value would make the cache location
    depend on the current working directory.
    """
    xdg = os.environ.get("XDG_CACHE_HOME", "")
    if xdg and Path(xdg).is_absolute():
        base = Path(xdg)
    else:
        base = Path.home() / ".cache"
    return base / "rgpycrumbs" / "deps"
def _resolve_pip_spec(module_name: str) -> str:
    """Look up the pip requirement for *module_name*, taking CUDA into account.

    The registry spec is returned on CUDA hosts.  Otherwise, if a CPU-only
    override is registered for the module's top-level package, that
    override wins.  *module_name* must be a key of ``_DEPENDENCY_MAP``;
    otherwise ``KeyError`` is raised.
    """
    default_spec, _ = _DEPENDENCY_MAP[module_name]
    if _has_cuda():
        return default_spec
    top_level, _sep, _rest = module_name.partition(".")
    return _CPU_OVERRIDES.get(top_level, default_spec)
def _uv_install(package_spec: str, target: Path) -> None:
    """Install *package_spec* into *target* using uv (falling back to pip).

    *target* is created if needed.  Installers are tried in this order,
    stopping at the first that succeeds:

    1. ``uv pip install`` when ``uv`` is on PATH.  ``--python`` pins uv to
       the running interpreter, so wheels (including compiled extensions)
       are resolved for this Python.  Otherwise uv would target whichever
       interpreter it discovers on its own.
    2. ``<sys.executable> -m pip install``, i.e. the running interpreter's
       own pip.
    3. ``pip`` found on PATH, as a last resort.  It may belong to a
       different interpreter.

    Raises ``RuntimeError`` if no installer is available or all of them
    fail.  The message includes each installer's error output.
    """
    target.mkdir(parents=True, exist_ok=True)
    target_str = str(target)
    install_args = ["install", "--target", target_str, package_spec]

    attempts: list[tuple[str, list[str]]] = []
    uv_exe = shutil.which("uv")
    if uv_exe is not None:
        uv_cmd = [uv_exe, "pip", "install"]
        # sys.executable may be empty/None in embedded interpreters.
        if sys.executable:
            uv_cmd += ["--python", sys.executable]
        uv_cmd += ["--target", target_str, package_spec]
        attempts.append(("uv", uv_cmd))
    if sys.executable:
        attempts.append(("python -m pip", [sys.executable, "-m", "pip", *install_args]))
    pip_exe = shutil.which("pip")
    if pip_exe is not None:
        attempts.append(("pip", [pip_exe, *install_args]))

    failures: list[str] = []
    for label, cmd in attempts:
        logger.info("rgpycrumbs: installing %s via %s", package_spec, label)
        try:
            subprocess.run(  # noqa: S603
                cmd,
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as exc:
            # capture_output hides the installer's diagnostics; surface them.
            stderr = (exc.stderr or b"").decode("utf-8", errors="replace").strip()
            detail = f"{exc}\n{stderr}" if stderr else str(exc)
        except OSError as exc:
            detail = str(exc)
        else:
            return
        logger.debug("%s install failed: %s", label, detail)
        failures.append(f"{label}: {detail}")

    if not failures:
        msg = f"Failed to install {package_spec}. Ensure uv or pip is available on PATH."
    else:
        msg = f"Failed to install {package_spec}. Installer errors:\n\n" + "\n\n".join(
            failures
        )
    raise RuntimeError(msg)
def ensure_import(module_name: str):
    """Import *module_name* through a 5-step priority chain.

    1. Current environment (importlib)
    2. Parent environment (RGPYCRUMBS_PARENT_SITE_PACKAGES)
    3. uv cache directory on sys.path
    4. uv/pip install into cache (opt-in via RGPYCRUMBS_AUTO_DEPS=1)
    5. Raise ImportError with an actionable message

    Returns the imported module object.

    Raises ``ImportError`` when every step fails.  The message names the
    rgpycrumbs extra to install, or suggests pixi for packages not in the
    dependency registry.  A ``RuntimeError`` from a failed auto-install in
    step 4 is not caught here and propagates to the caller.

    .. versionadded:: 1.3.0
    """
    # Step 1: current env
    try:
        return importlib.import_module(module_name)
    except ImportError:
        pass

    # Step 2: parent env
    mod = _import_from_parent_env(module_name)
    if mod is not None:
        return mod

    # Step 3: check uv cache.  Retrying only makes sense when the cache dir
    # is newly put on sys.path; if it was already there, step 1 covered it.
    cache_dir = _get_dep_cache_dir()
    cache_str = str(cache_dir)
    if cache_dir.is_dir() and cache_str not in sys.path:
        sys.path.insert(0, cache_str)
        try:
            return importlib.import_module(module_name)
        except ImportError:
            pass

    # Step 4: auto-install (opt-in)
    auto = os.environ.get("RGPYCRUMBS_AUTO_DEPS", "").strip()
    if auto == "1" and module_name in _DEPENDENCY_MAP:
        spec = _resolve_pip_spec(module_name)
        _uv_install(spec, cache_dir)
        if cache_str not in sys.path:
            sys.path.insert(0, cache_str)
        # Path-based finders cache directory listings.  Invalidate them so
        # packages just written into an already-scanned cache dir are found.
        importlib.invalidate_caches()
        try:
            return importlib.import_module(module_name)
        except ImportError:
            pass

    # Step 5: actionable error
    if module_name in _DEPENDENCY_MAP:
        _spec, extra = _DEPENDENCY_MAP[module_name]
        if module_name in ("jax", "jaxlib"):
            # Special handling for JAX with detailed instructions
            msg = (
                "\n"
                "JAX is required for surface fitting and Gaussian Process models.\n"
                "\n"
                "Quick install:\n"
                '  pip install "rgpycrumbs[surfaces]"\n'
                "\n"
                "Or enable auto-install:\n"
                "  export RGPYCRUMBS_AUTO_DEPS=1\n"
                "\n"
                "For GPU support:\n"
                '  pip install "jax[cuda12]"  # CUDA 12\n'
                '  pip install "jax[cuda11]"  # CUDA 11\n'
                "\n"
                "See: https://jax.readthedocs.io/en/latest/installation.html\n"
            )
        else:
            # Quote the requirement: an unquoted "[...]" is a glob in zsh and
            # the command fails with "no matches found".
            msg = (
                f"Module '{module_name}' is required.\n\n"
                f"Install with:\n"
                f'  pip install "rgpycrumbs[{extra}]"\n\n'
                f"Or enable auto-install:\n"
                f"  export RGPYCRUMBS_AUTO_DEPS=1"
            )
    else:
        msg = (
            f"Module '{module_name}' is not installed and is not a "
            "pip-installable dependency of rgpycrumbs. "
            "For conda-only packages (ira_mod, tblite, ovito), "
            "use pixi: pixi install"
        )
    raise ImportError(msg)
class _LazyModule:
    """Proxy that defers ``ensure_import`` until first attribute access.

    The first attribute lookup (or ``dir()`` call) resolves the real module
    with :func:`ensure_import` and stores it on the proxy.  Each later
    attribute lookup is forwarded to that stored module through
    ``__getattr__``, so no further import work happens.

    .. versionadded:: 1.3.0
    """

    def __init__(self, module_name: str):
        # The proxy's own state is written/read via object.__setattr__ /
        # object.__getattribute__, so it never goes through the forwarding
        # __getattr__ below.
        object.__setattr__(self, "_module_name", module_name)
        object.__setattr__(self, "_module", None)

    def _resolve(self):
        """Import the target module on first use, cache it, and return it."""
        mod = object.__getattribute__(self, "_module")
        if mod is None:
            name = object.__getattribute__(self, "_module_name")
            mod = ensure_import(name)
            object.__setattr__(self, "_module", mod)
        return mod

    def __getattr__(self, attr):
        # Called only for names not found on the proxy itself, i.e. anything
        # other than its private state and methods.
        return getattr(self._resolve(), attr)

    def __dir__(self):
        # Show the real module's namespace (introspection, tab completion).
        return dir(self._resolve())

    def __repr__(self):
        name = object.__getattribute__(self, "_module_name")
        mod = object.__getattribute__(self, "_module")
        if mod is None:
            return f"<LazyModule '{name}' (unresolved)>"
        return repr(mod)
def lazy_import(module_name: str) -> _LazyModule:
    """Create a deferred-import proxy for *module_name*.

    Nothing is imported by this call.  :func:`ensure_import` runs only when
    an attribute of the returned proxy is first accessed.

    .. versionadded:: 1.3.0
    """
    proxy = _LazyModule(module_name)
    return proxy
def getstrform(pathobj):
    """Return the absolute path of *pathobj* as a string.

    *pathobj* may be a :class:`pathlib.Path` or anything the ``Path``
    constructor accepts (``str``, :class:`os.PathLike`).  Relative paths
    are made absolute against the current working directory.  As with
    :meth:`pathlib.Path.absolute`, symlinks are not resolved.

    .. versionadded:: 0.0.1
    """
    return str(Path(pathobj).absolute())
def get_gitroot():
    """Return the root of the current git repository as a Path.

    Runs ``git rev-parse --show-toplevel`` in the current working directory.

    Raises ``subprocess.CalledProcessError`` if git exits non-zero (for
    example outside a repository).  Raises ``OSError`` (typically
    ``FileNotFoundError``) if no git executable can be run.

    .. versionadded:: 0.0.1
    """
    git_path = shutil.which("git") or "git"
    result = subprocess.run(  # noqa: S603
        [git_path, "rev-parse", "--show-toplevel"],
        check=True,
        capture_output=True,
        cwd=Path.cwd(),
    )
    # Decode with the filesystem encoding (surrogateescape) so non-UTF-8 path
    # bytes round-trip instead of raising.  Strip only the line terminator,
    # so directory names ending in whitespace are preserved.
    return Path(os.fsdecode(result.stdout).rstrip("\r\n"))
@contextlib.contextmanager
def switchdir(path):
    """Run the ``with`` body with *path* as the working directory.

    The directory that was current on entry is restored on exit, even if
    the body raises.

    .. versionadded:: 0.0.1
    """
    origin = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(origin)
def _import_from_parent_env(module_name: str):
    """
    Import a module, falling back to the parent interpreter's site-packages.

    The current environment is tried first.  If that fails and
    ``RGPYCRUMBS_PARENT_SITE_PACKAGES`` (an ``os.pathsep``-separated list)
    is non-empty, its non-empty entries that are not already on ``sys.path``
    are appended for a single import attempt.  They are removed again
    afterwards, whatever the outcome.  ``importlib.import_module`` returns
    the leaf module for dotted names such as ``'tblite.interface'``, whereas
    ``__import__`` would return the top-level package.

    Returns the module, or ``None`` if it cannot be imported either way.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        pass

    raw_paths = os.environ.get("RGPYCRUMBS_PARENT_SITE_PACKAGES", "")
    if not raw_paths:
        return None

    added = [
        entry
        for entry in raw_paths.split(os.pathsep)
        if entry and entry not in sys.path
    ]
    sys.path.extend(added)
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None
    finally:
        # Undo the temporary sys.path extension.
        for entry in added:
            with contextlib.suppress(ValueError):
                sys.path.remove(entry)