# ADR-001: JAX vs NumPy Computation Backend

**Status:** ACCEPTED
**Date:** 2026-04-06
**Deciders:** Architecture Team
**Supersedes:** None

---

## Context

pyXRayLabTool v0.3.0 performs all X-ray optical property calculations using NumPy arrays with SciPy for interpolation. The computation pipeline is:

```
parse_formula -> load_scattering_data -> PchipInterpolator(energy)
    -> calculate_scattering_factors (einsum/broadcast)
    -> calculate_derived_quantities (sqrt, exp, division)
    -> XRayResult
```

**Current performance characteristics observed in the codebase:**

- `calculate_scattering_factors()` dominates runtime for energy sweeps (>100 points)
- The inner loop iterates over elements, calling interpolators and accumulating weighted sums
- `optimization/vectorized_core.py` already restructures this as matrix ops (`np.einsum`)
- `optimization/optimized_core.py` provides monkey-patching for data loading (2-3x faster I/O)
- Manual SIMD heuristics in `vectorized_core.py` (AVX detection, threshold tuning) indicate the team has hit the ceiling of what NumPy can do

**What JAX offers for this workload:**

1. **JIT compilation:** Fuses the entire `scattering_factors + derived_quantities` pipeline into a single XLA program. Eliminates Python dispatch overhead and intermediate array allocations.
2. **vmap:** Replaces the explicit for-loop over materials in `compute_multiple()` and `batch_processing.py` with automatic vectorization.
3. **Automatic differentiation:** Enables future gradient-based optimization (NLSQ warm-start for Bayesian fitting per CLAUDE.md guidelines).
4. **Platform portability:** Same code runs on CPU (via XLA), GPU (CUDA/ROCm), and Apple Silicon (Metal plugin).

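
To make the JIT and vmap points above concrete, here is a minimal sketch of a weighted sum over elements, compiled with `jax.jit` and then batched over materials with `jax.vmap`. It is not the project's implementation. The function names `weighted_scattering` and `batched_scattering`, the array shapes, and the sample values are all assumptions chosen for illustration. It also assumes every material shares the same (possibly zero-padded) element axis.

```python
import jax
import jax.numpy as jnp

# Assumption: double precision is wanted. JAX defaults to 32-bit floats.
jax.config.update("jax_enable_x64", True)


@jax.jit
def weighted_scattering(f1, f2, counts):
    """Sum per-element scattering factors weighted by atom counts.

    f1, f2: (n_elements, n_energies) interpolated scattering factors.
    counts: (n_elements,) atoms of each element per formula unit.
    """
    f1_total = jnp.einsum("e,en->n", counts, f1)
    f2_total = jnp.einsum("e,en->n", counts, f2)
    return f1_total, f2_total


# Batch over a leading "material" axis of counts; f1 and f2 are shared.
batched_scattering = jax.vmap(weighted_scattering, in_axes=(None, None, 0))

# Illustrative data: 2 elements, 3 energies, 2 materials.
f1 = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
f2 = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
counts = jnp.array([[1.0, 2.0], [2.0, 1.0]])

f1_batch, f2_batch = batched_scattering(f1, f2, counts)
print(f1_batch.shape)  # (2, 3): one row of energies per material
```

The same `weighted_scattering` function could in principle be passed to `jax.grad`, which is what the automatic differentiation point relies on.
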
**Which computations benefit most:**

| Function | Pattern | JIT Benefit | vmap Benefit |
|----------|---------|-------------|--------------|
| `calculate_scattering_factors` | Matrix multiply (einsum) + broadcast | HIGH -- fuses multiply+sum+scale into single kernel | HIGH -- vectorize over materials |
| `calculate_derived_quantities` | Element-wise sqrt, division, multiply | MODERATE -- eliminates 5 intermediate arrays | LOW -- already vectorized |
| `PchipInterpolator.__call__` | Piecewise polynomial evaluation | HIGH -- fuses search+evaluate into single kernel | HIGH -- vectorize over elements |
| `energy_to_wavelength` | Scalar division | LOW -- too simple to benefit | N/A |
| `parse_formula` | Regex string parsing | NONE -- not numerical | N/A |
| Multi-material batch | Loop over formulas | MODERATE | HIGH -- `vmap` eliminates loop |

## Decision

**Adopt JAX as the primary computation backend, with NumPy retained at I/O boundaries.**

Specifically:

1. All numerical computation in `calculators/`, `optimization/`, and `data_handling/` will use `jax.numpy` via a backend abstraction layer.
2. The `backend/` module will provide a `Protocol`-based `ArrayBackend` interface with `NumpyBackend` and `JaxBackend` implementations.
3. `NumpyBackend` will be the default during migration (zero behavior change), switchable to `JaxBackend` via `set_backend("jax")` or environment variable.
4. After Phase 2, `JaxBackend` becomes the default with `NumpyBackend` as fallback.
5. I/O operations (file loading, CSV export, pandas DataFrames) remain NumPy/pandas.
6. The `PchipInterpolator` dependency is abstracted behind `InterpolationFactory`, allowing scipy and interpax implementations to coexist.

## Consequences

### Positive

- **Performance:** JIT-compiled scattering factor calculation should achieve 2-10x speedup for energy sweeps >100 points by eliminating Python dispatch and fusing operations.
- **Simplification:** The entire `optimization/vectorized_core.py` module (700 lines of manual SIMD heuristics, AVX detection, contiguity checks) becomes unnecessary -- JAX's XLA compiler handles all of this automatically.
- **Future capability:** `jax.grad` enables automatic differentiation through the full calculation pipeline, unlocking GPU-accelerated NLSQ fitting and Bayesian inference via NumPyro.
- **Batch processing:** `jax.vmap` replaces `ThreadPoolExecutor`-based parallel processing in `batch_processing.py` with hardware-vectorized batch computation.

### Negative

- **Cold start:** JIT compilation adds 100ms-2s to the first calculation call. Mitigated by AOT compilation at import time for hot-path functions.
- **Debugging complexity:** JIT-compiled functions produce opaque XLA error messages. Mitigated by the `jax.disable_jit()` context manager during development.
- **Dependency weight:** JAX + jaxlib adds ~200MB to the install footprint. Mitigated by making JAX optional (`pip install xraylabtool[jax]`).
- **Learning curve:** Team members need to understand JAX's functional purity constraints (no in-place mutation, no Python control flow on traced values inside JIT).

### Risks

- **interpax maturity:** The PCHIP interpolation adapter depends on interpax for JIT-compatible interpolation. If interpax proves inadequate, the fallback is `jax.pure_callback` wrapping scipy (with a JIT boundary penalty).
- **Numerical equivalence:** JAX's XLA compiler may reorder floating-point operations, producing results that differ at the ~1e-15 level. This estimate holds only in 64-bit mode (`jax_enable_x64`), because JAX defaults to 32-bit floats. The golden test suite with 1e-12 tolerance provides a safety net.

---

## Appendix: Functions to JIT-Compile (Priority Order)

1. `calculate_scattering_factors()` -- Hot path, called once per material per calculation
2. `calculate_derived_quantities()` -- Called immediately after scattering factors
3. `vectorized_interpolation_batch()` -- Inner loop of multi-element processing
4. `vectorized_multi_material_batch()` -- Outer batch loop (candidate for vmap)
5. `EnergyConfig.to_array()` -- Trivial but called frequently

## Appendix: Functions that MUST NOT be JIT-compiled

1. `load_scattering_factor_data()` -- File I/O, side effects (caching)
2. `create_scattering_factor_interpolators()` -- Object creation, side effects (LRU cache)
3. `parse_formula()` -- Regex string processing
4. `_warm_priority_cache()` -- Threading, side effects
5. Any function that raises `ValueError`/`FileNotFoundError` -- under `jax.jit`, Python code runs only at trace time, so value-dependent checks cannot fire on traced arrays; validate before the JIT boundary
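
The split between the two appendix lists can be illustrated with a small sketch in which validation runs eagerly in plain Python and only the numerical core is JIT-compiled. The names `_wavelength_kernel` and `energy_to_wavelength_checked` are hypothetical, and the units (keV in, Angstrom out) are an assumption. The project's actual `energy_to_wavelength` may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

# hc in keV * Angstrom (assumed units for this sketch)
HC_KEV_ANGSTROM = 12.398419843


@jax.jit
def _wavelength_kernel(energy_kev):
    """Pure numerical core: no I/O, no exceptions, no caching."""
    return HC_KEV_ANGSTROM / energy_kev


def energy_to_wavelength_checked(energy_kev):
    """Validate eagerly on concrete values, then call the jitted kernel."""
    energy = np.asarray(energy_kev, dtype=float)
    # This check runs outside JIT, where the values are concrete.
    if np.any(energy <= 0):
        raise ValueError("energies must be positive")
    return _wavelength_kernel(jnp.asarray(energy))


wavelengths = energy_to_wavelength_checked([8.0, 10.0, 12.398419843])
```

Moving the `if` inside the jitted function would fail at trace time, because a traced array cannot be converted to a Python boolean. That is the practical reason for keeping exception-raising code on the eager side.
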