from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
import numpy as np
if TYPE_CHECKING:
pass
Array = Any
[docs]
@runtime_checkable
class ArrayBackend(Protocol):
@property
def float64(self) -> Any: ...
[docs]
def zeros(self, shape: Any, dtype: Any = np.float64) -> Array: ...
[docs]
def ones(self, shape: Any, dtype: Any = np.float64) -> Array: ...
[docs]
def asarray(self, x: Any, dtype: Any = None) -> Array: ...
[docs]
def square(self, x: Any) -> Array: ...
[docs]
def sqrt(self, x: Any) -> Array: ...
[docs]
def exp(self, x: Any) -> Array: ...
[docs]
def sum(self, x: Any, axis: Any = None) -> Array: ...
[docs]
def where(self, condition: Any, x: Any, y: Any) -> Array: ...
[docs]
def maximum(self, x: Any, y: Any) -> Array: ...
[docs]
def einsum(self, subscripts: str, *operands: Any) -> Array: ...
[docs]
def isnan(self, x: Any) -> Array: ...
[docs]
def isinf(self, x: Any) -> Array: ...
[docs]
def isfinite(self, x: Any) -> Array: ...
[docs]
def any(self, x: Any) -> bool: ...
[docs]
def all(self, x: Any) -> bool: ...
[docs]
def linspace(self, start: Any, stop: Any, num: int) -> Array: ...
[docs]
def logspace(self, start: Any, stop: Any, num: int) -> Array: ...
[docs]
def argsort(self, x: Any) -> Array: ...
[docs]
def ascontiguousarray(self, x: Any) -> Array: ...
[docs]
class NumpyBackend:
@property
def float64(self) -> Any:
return np.float64
[docs]
def zeros(self, shape: Any, dtype: Any = np.float64) -> Array:
return np.zeros(shape, dtype=dtype)
[docs]
def ones(self, shape: Any, dtype: Any = np.float64) -> Array:
return np.ones(shape, dtype=dtype)
[docs]
def asarray(self, x: Any, dtype: Any = None) -> Array:
return np.asarray(x, dtype=dtype)
[docs]
def square(self, x: Any) -> Array:
return np.square(x)
[docs]
def sqrt(self, x: Any) -> Array:
return np.sqrt(x)
[docs]
def exp(self, x: Any) -> Array:
return np.exp(x)
[docs]
def sum(self, x: Any, axis: Any = None) -> Array:
return np.sum(x, axis=axis)
[docs]
def where(self, condition: Any, x: Any, y: Any) -> Array:
return np.where(condition, x, y)
[docs]
def maximum(self, x: Any, y: Any) -> Array:
return np.maximum(x, y)
[docs]
def einsum(self, subscripts: str, *operands: Any) -> Array:
return np.einsum(subscripts, *operands)
[docs]
def isnan(self, x: Any) -> Array:
return np.isnan(x)
[docs]
def isinf(self, x: Any) -> Array:
return np.isinf(x)
[docs]
def isfinite(self, x: Any) -> Array:
return np.isfinite(x)
[docs]
def any(self, x: Any) -> bool:
return bool(np.any(x))
[docs]
def all(self, x: Any) -> bool:
return bool(np.all(x))
[docs]
def linspace(self, start: Any, stop: Any, num: int) -> Array:
return np.linspace(start, stop, num)
[docs]
def logspace(self, start: Any, stop: Any, num: int) -> Array:
return np.logspace(start, stop, num)
[docs]
def argsort(self, x: Any) -> Array:
return np.argsort(x)
[docs]
def ascontiguousarray(self, x: Any) -> Array:
return np.ascontiguousarray(x)
[docs]
class JaxBackend:
[docs]
def __init__(self) -> None:
import jax # type: ignore[import-not-found]
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp # type: ignore[import-not-found]
self._jnp = jnp
@property
def float64(self) -> Any:
return self._jnp.float64
[docs]
def zeros(self, shape: Any, dtype: Any = None) -> Array:
dtype = dtype if dtype is not None else self._jnp.float64
return self._jnp.zeros(shape, dtype=dtype)
[docs]
def ones(self, shape: Any, dtype: Any = None) -> Array:
dtype = dtype if dtype is not None else self._jnp.float64
return self._jnp.ones(shape, dtype=dtype)
[docs]
def asarray(self, x: Any, dtype: Any = None) -> Array:
return self._jnp.asarray(x, dtype=dtype)
[docs]
def square(self, x: Any) -> Array:
return self._jnp.square(x)
[docs]
def sqrt(self, x: Any) -> Array:
return self._jnp.sqrt(x)
[docs]
def exp(self, x: Any) -> Array:
return self._jnp.exp(x)
[docs]
def sum(self, x: Any, axis: Any = None) -> Array:
return self._jnp.sum(x, axis=axis)
[docs]
def where(self, condition: Any, x: Any, y: Any) -> Array:
return self._jnp.where(condition, x, y)
[docs]
def maximum(self, x: Any, y: Any) -> Array:
return self._jnp.maximum(x, y)
[docs]
def einsum(self, subscripts: str, *operands: Any) -> Array:
return self._jnp.einsum(subscripts, *operands)
[docs]
def isnan(self, x: Any) -> Array:
return self._jnp.isnan(x)
[docs]
def isinf(self, x: Any) -> Array:
return self._jnp.isinf(x)
[docs]
def isfinite(self, x: Any) -> Array:
return self._jnp.isfinite(x)
[docs]
def any(self, x: Any) -> bool:
return bool(self._jnp.any(x))
[docs]
def all(self, x: Any) -> bool:
return bool(self._jnp.all(x))
[docs]
def linspace(self, start: Any, stop: Any, num: int) -> Array:
return self._jnp.linspace(start, stop, num)
[docs]
def logspace(self, start: Any, stop: Any, num: int) -> Array:
return self._jnp.logspace(start, stop, num)
[docs]
def argsort(self, x: Any) -> Array:
return self._jnp.argsort(x)
[docs]
def ascontiguousarray(self, x: Any) -> Array:
return self._jnp.asarray(x)
def _has_nvidia_gpu() -> bool:
"""Fast check for NVIDIA GPU without importing JAX (~1ms)."""
import shutil
return shutil.which("nvidia-smi") is not None
def _auto_select_backend() -> ArrayBackend:
"""Select the best available backend: JAX GPU > JAX CPU > NumPy.
JAX is preferred when a GPU is available (10-50x speedup on large arrays).
On CPU-only systems, NumPy is faster for typical workloads (<5000 points)
due to lower per-call dispatch overhead vs JAX's XLA runtime.
Avoids importing JAX on CPU-only systems to keep cold start fast (~300ms
vs ~900ms with JAX import). Only pays the JAX import cost when a GPU is
detected and JAX is installed.
"""
import importlib.util
if not importlib.util.find_spec("jax"):
return NumpyBackend()
if not _has_nvidia_gpu():
return NumpyBackend()
# GPU detected and JAX installed — pay the import cost
try:
import jax # type: ignore[import-not-found]
if jax.default_backend() == "gpu":
return JaxBackend()
except (ImportError, RuntimeError):
pass
return NumpyBackend()
_backend: ArrayBackend = _auto_select_backend()
[docs]
def get_backend() -> ArrayBackend:
return _backend
[docs]
def set_backend(name: str) -> None:
global _backend
if name == "jax":
_backend = JaxBackend()
elif name == "numpy":
_backend = NumpyBackend()
else:
raise ValueError(f"Unknown backend: {name!r}")
ops._invalidate_cache()
# Clear JIT cache so kernels recompile for the new backend
try:
from xraylabtool.calculators.core import _jit_cache
_jit_cache.clear()
except ImportError:
pass
class _OpsProxy:
"""Proxy that delegates to the active backend with method caching."""
def __getattr__(self, name: str) -> Any:
attr = getattr(_backend, name)
# Cache on the proxy instance for subsequent calls
object.__setattr__(self, name, attr)
return attr
def is_jax(self) -> bool:
"""Check whether the active backend is JAX."""
return isinstance(_backend, JaxBackend)
def _invalidate_cache(self) -> None:
"""Clear cached methods when backend changes."""
for key in list(self.__dict__):
object.__delattr__(self, key)
ops = _OpsProxy()