Source code for xraylabtool.calculators.cache

"""Cache infrastructure for scattering factor data and interpolators."""

# ruff: noqa: PLW0603

from __future__ import annotations

from functools import cache, lru_cache
import types
from typing import TYPE_CHECKING, Any

from xraylabtool.backend import InterpolationFactory, ops

if TYPE_CHECKING:
    from xraylabtool.typing_extensions import InterpolatorProtocol


# Module-level cache for f1/f2 scattering tables, keyed by element symbol.
# Values are typed as Any to avoid an early pandas import.
_scattering_factor_cache: dict[str, Any] = {}

# Module-level cache of (f1, f2) interpolator pairs, keyed by the element
# string passed to create_scattering_factor_interpolators(), so the pairs are
# not rebuilt on every request. The precise value type is declared only for
# static type checkers; at runtime the values are typed as Any.
if TYPE_CHECKING:
    _interpolator_cache: dict[
        str, tuple[InterpolatorProtocol, InterpolatorProtocol]
    ] = {}
else:
    _interpolator_cache: dict[str, Any] = {}

# Commonly used elements that are pre-loaded to improve cold start performance.
_PRIORITY_ELEMENTS = ["H", "C", "N", "O", "Si", "Al", "Ca", "Fe", "Cu", "Zn"]
# Set to True once cache warming has been started or has completed; reset to
# False by clear_scattering_factor_cache().
_CACHE_WARMED = False

# Atomic data cache for bulk lookups.
# NOTE(review): nothing visible in this module appears to write to this dict;
# it is only emptied by clear_scattering_factor_cache() - confirm it is still used.
_atomic_data_cache: dict[str, dict[str, float]] = {}


def get_cached_elements() -> list[str]:
    """
    Report which elements currently have scattering factor data in memory.

    Returns:
        A new list holding the keys of the module-level scattering factor
        cache (element symbols), in the order they were inserted.
    """
    # Unpacking a dict iterates over its keys, so this snapshots the key set
    # into a fresh list that callers can mutate without touching the cache.
    return [*_scattering_factor_cache]
@cache
def get_bulk_atomic_data(
    elements_tuple: tuple[str, ...],
) -> dict[str, types.MappingProxyType[str, float]]:
    """
    Fetch atomic data for several elements in one memoized call.

    The work is delegated to ``get_bulk_atomic_data_fast`` from
    ``xraylabtool.data_handling.atomic_cache``. Results are memoized by
    ``functools.cache`` using ``elements_tuple`` as the key, so a repeated call
    with an equal tuple returns the very same dictionary object without
    invoking the loader again. Tuples that contain the same symbols in a
    different order are distinct keys and are cached separately.

    Args:
        elements_tuple: Tuple of element symbols to load data for. It must be
            hashable, which is why a tuple (not a list) is required.

    Returns:
        Dictionary mapping element symbols to their atomic data, exactly as
        returned by the underlying loader.
    """
    # NOTE(review): the memoization is unbounded; entries are only released by
    # get_bulk_atomic_data.cache_clear() - confirm the key space stays small.
    # Function-scope import keeps the data-handling package out of module import.
    from xraylabtool.data_handling.atomic_cache import (
        get_bulk_atomic_data_fast as _load_bulk,
    )

    atomic_data = _load_bulk(elements_tuple)
    return atomic_data
def _warm_priority_cache() -> None:
    """
    Start best-effort warming of the atomic data cache for priority elements.

    If ``_CACHE_WARMED`` is already set this is a no-op. Otherwise a daemon
    thread is started that loads atomic data for ``_PRIORITY_ELEMENTS`` through
    ``get_bulk_atomic_data_fast``; the caller does not wait for it to finish.
    ``_CACHE_WARMED`` is set to True right after the thread is started, so a
    later call does not start a second thread.

    Failures inside the background thread never propagate. They are logged at
    DEBUG level (with traceback) on this module's logger, and ``_CACHE_WARMED``
    is left set so that warming is not retried on every call.
    """
    global _CACHE_WARMED

    if _CACHE_WARMED:
        return

    # Function-scope imports, consistent with the other lazy imports here.
    import logging
    import threading

    def _background_cache_warming() -> None:
        """Load the priority elements; runs on the background thread."""
        global _CACHE_WARMED

        try:
            from xraylabtool.data_handling.atomic_cache import (
                get_bulk_atomic_data_fast,
            )

            get_bulk_atomic_data_fast(tuple(_PRIORITY_ELEMENTS))
            _CACHE_WARMED = True
        except Exception:
            # Warming is an optimization only, so a failure must not surface to
            # the user - but record it so the failure can be diagnosed.
            logging.getLogger(__name__).debug(
                "Background warming of the priority element cache failed",
                exc_info=True,
            )

    # Start background warming but do not wait for it.
    warming_thread = threading.Thread(target=_background_cache_warming, daemon=True)
    warming_thread.start()

    # Mark as "warming in progress" to avoid multiple attempts.
    _CACHE_WARMED = True


def _smart_cache_warming(formula: str) -> None:
    """
    Warm the atomic data cache with only the elements a formula needs.

    The formula is parsed with ``xraylabtool.utils.parse_formula`` and the
    resulting element symbols are loaded through ``get_bulk_atomic_data_fast``.
    On success ``_CACHE_WARMED`` is set to True. If anything in that sequence
    raises, the error is logged at DEBUG level (with traceback) and the
    function falls back to ``_warm_priority_cache()``; no exception is raised
    to the caller from the parsing/loading step.

    Args:
        formula: Chemical formula to analyze for required elements
            (e.g., "SiO2")
    """
    global _CACHE_WARMED

    import logging

    try:
        from xraylabtool.utils import parse_formula

        # parse_formula returns a pair; only the element symbols are needed.
        element_symbols, _ = parse_formula(formula)

        from xraylabtool.data_handling.atomic_cache import get_bulk_atomic_data_fast

        get_bulk_atomic_data_fast(tuple(element_symbols))
        _CACHE_WARMED = True
    except Exception:
        # Best-effort: keep the fallback, but leave a trace of why it happened.
        logging.getLogger(__name__).debug(
            "Smart cache warming failed for formula %r; "
            "falling back to priority warming",
            formula,
            exc_info=True,
        )
        _warm_priority_cache()
def clear_scattering_factor_cache() -> None:
    """
    Reset every cache owned by this module.

    Empties the scattering factor, interpolator and atomic data dictionaries,
    resets the ``_CACHE_WARMED`` flag to False, and clears the functools
    memoization of ``get_bulk_atomic_data`` and
    ``create_scattering_factor_interpolators``. Useful for testing or memory
    management.
    """
    global _CACHE_WARMED

    # Plain dict caches: emptied in place so existing references stay valid.
    for store in (
        _scattering_factor_cache,
        _interpolator_cache,
        _atomic_data_cache,
    ):
        store.clear()

    _CACHE_WARMED = False

    # functools-managed caches attached to the memoized functions.
    for memoized in (
        get_bulk_atomic_data,
        create_scattering_factor_interpolators,
    ):
        memoized.cache_clear()
def is_element_cached(element: str) -> bool:
    """
    Tell whether scattering factor data for an element is already in memory.

    The symbol is normalized with ``str.capitalize()`` (first character upper
    case, the rest lower case) before the lookup, so "si", "SI" and "Si" all
    test the key "Si".

    Args:
        element: Element symbol to check

    Returns:
        True if the normalized symbol is a key of the scattering factor cache,
        False otherwise
    """
    normalized_symbol = element.capitalize()
    return normalized_symbol in _scattering_factor_cache
@lru_cache(maxsize=128)
def create_scattering_factor_interpolators(
    element: str,
) -> tuple[InterpolatorProtocol, InterpolatorProtocol]:
    """
    Build (or fetch) the PCHIP interpolators for an element's f1 and f2 data.

    The element's scattering factor table is loaded with
    ``load_scattering_factor_data``; its "E", "f1" and "f2" columns are
    converted to backend arrays, sorted by energy if they are not already in
    non-decreasing order, and handed to ``InterpolationFactory.create_pchip``
    with ``extrapolate=False``. The resulting pair is stored in the
    module-level interpolator cache under ``element`` and is additionally
    memoized by ``functools.lru_cache`` (up to 128 distinct arguments).

    Args:
        element: Element symbol (e.g., 'H', 'C', 'N', 'O', 'Si', 'Ge'). It is
            used verbatim as the cache key.

    Returns:
        Tuple of (f1_interpolator, f2_interpolator) where each is a callable
        that takes energy values and returns interpolated scattering factors

    Raises:
        ValueError: If the loaded table has fewer than 2 rows.
        Exception: Anything raised by ``load_scattering_factor_data`` is
            propagated unchanged (the original documentation lists
            FileNotFoundError for a missing .nff file and ValueError for an
            invalid element symbol - confirm against that loader).

    Examples:
        >>> from xraylabtool.calculators.cache import create_scattering_factor_interpolators
        >>> import numpy as np
        >>> f1_interp, f2_interp = create_scattering_factor_interpolators('Si')
        >>> energy = 100.0  # eV
        >>> f1_value = f1_interp(energy)
        >>> isinstance(f1_value, (int, float, np.number, np.ndarray))
        True
        >>> f2_value = f2_interp(energy)
        >>> isinstance(f2_value, (int, float, np.number, np.ndarray))
        True
        >>> # Can also handle arrays
        >>> energies = np.array([100.0, 200.0, 300.0])
        >>> f1_values = f1_interp(energies)
        >>> len(f1_values) == 3
        True
    """
    from xraylabtool.calculators.scattering_data import load_scattering_factor_data

    # Fast path: reuse a pair that was built earlier for this exact key.
    try:
        return _interpolator_cache[element]
    except KeyError:
        pass

    table = load_scattering_factor_data(element)

    # PCHIP needs at least two support points.
    point_count = len(table)
    if point_count < 2:
        raise ValueError(
            f"Insufficient data points for element '{element}'. "
            "PCHIP interpolation requires at least 2 points, "
            f"found {point_count}."
        )

    # Pull the three columns out as backend arrays (energy, f1, f2 in order).
    energies, f1_table, f2_table = (
        ops.asarray(table[column].values) for column in ("E", "f1", "f2")
    )

    # PCHIP requires sorted x values; reorder all three arrays together when
    # any energy is larger than its successor.
    out_of_order = ops.any(ops.asarray(energies[:-1]) > ops.asarray(energies[1:]))
    if out_of_order:
        order = ops.argsort(energies)
        energies = energies[order]
        f1_table = f1_table[order]
        f2_table = f2_table[order]

    # Build both interpolators through the backend abstraction layer.
    interpolator_pair = (
        InterpolationFactory.create_pchip(energies, f1_table, extrapolate=False),
        InterpolationFactory.create_pchip(energies, f2_table, extrapolate=False),
    )

    # Remember the pair for later calls.
    _interpolator_cache[element] = interpolator_pair
    return interpolator_pair