ADR-001: JAX vs NumPy Computation Backend
- Status: ACCEPTED
- Date: 2026-04-06
- Deciders: Architecture Team
- Supersedes: None
Context
pyXRayLabTool v0.3.0 performs all X-ray optical property calculations using NumPy arrays with SciPy for interpolation. The computation pipeline is:
parse_formula -> load_scattering_data -> PchipInterpolator(energy)
-> calculate_scattering_factors (einsum/broadcast)
-> calculate_derived_quantities (sqrt, exp, division)
-> XRayResult
Current performance characteristics observed in the codebase:
- `calculate_scattering_factors()` dominates runtime for energy sweeps (>100 points).
- Its inner loop iterates over elements, calling interpolators and accumulating weighted sums.
- `optimization/vectorized_core.py` already restructures this as matrix operations (`np.einsum`).
- `optimization/optimized_core.py` monkey-patches data loading (2-3x faster I/O).
- Manual SIMD heuristics in `vectorized_core.py` (AVX detection, threshold tuning) indicate that the team has reached the limits of NumPy-level optimization.
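The sketch below shows the loop-to-`einsum` restructuring described above, using SciPy's `PchipInterpolator`. The element data, the `counts` vector, the shared energy grid, and both function names are hypothetical. The real tables used by `vectorized_core.py` are not reproduced here.

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

# Hypothetical tabulated f1 data for two elements on a shared energy grid (keV).
# Real scattering tables have per-element grids; sharing one keeps the sketch short.
table_energy = np.linspace(1.0, 30.0, 50)
table_f1 = np.stack([
    6.0 + 0.1 * np.sin(table_energy),   # stand-in for element A
    8.0 + 0.2 * np.cos(table_energy),   # stand-in for element B
])
counts = np.array([2.0, 1.0])           # atoms per formula unit, e.g. "A2B"

interpolators = [PchipInterpolator(table_energy, f1) for f1 in table_f1]
energies = np.linspace(5.0, 20.0, 200)


def f1_total_loop(energies):
    """Per-element loop: call each interpolator and accumulate a weighted sum."""
    total = np.zeros_like(energies)
    for n, interp in zip(counts, interpolators):
        total += n * interp(energies)
    return total


def f1_total_einsum(energies):
    """Matrix form: build an (elements, energies) array, then contract over elements."""
    f1_matrix = np.stack([interp(energies) for interp in interpolators])
    return np.einsum("e,ek->k", counts, f1_matrix)
```

Both versions still call one interpolator per element from Python. Only the accumulation moves into `np.einsum`, so the remaining per-element dispatch is what JIT compilation would target.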
What JAX offers for this workload:
- JIT compilation: Fuses the entire `scattering_factors` + `derived_quantities` pipeline into a single XLA program, eliminating Python dispatch overhead and most intermediate array allocations.
- vmap: Replaces the explicit for-loop over materials in `compute_multiple()` and `batch_processing.py` with automatic vectorization.
- Automatic differentiation: Enables future gradient-based optimization (NLSQ warm-start for Bayesian fitting, per the CLAUDE.md guidelines).
- Platform portability: The same code runs on CPU (via XLA), GPU (CUDA/ROCm), and Apple Silicon (via the experimental Metal plugin).
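The sketch below shows how `jax.jit` and `jax.vmap` could combine for this workload. It assumes the interpolated `f1`/`f2` values are already available as arrays. `scattering_and_derived`, its arguments, and the single `scale` prefactor (standing in for the density and wavelength terms) are illustrative, not the project's API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match NumPy's float64 default


@jax.jit
def scattering_and_derived(counts, f1_matrix, f2_matrix, scale):
    """Hypothetical fused kernel: weighted sums followed by element-wise math.

    counts:    (n_elements,) atoms per formula unit
    f1_matrix: (n_elements, n_energies) interpolated f1 values
    f2_matrix: (n_elements, n_energies) interpolated f2 values
    scale:     (n_energies,) prefactor standing in for density/wavelength terms
    """
    f1 = counts @ f1_matrix   # weighted sum over elements
    f2 = counts @ f2_matrix
    delta = scale * f1        # element-wise ops that XLA can fuse with the sums
    beta = scale * f2
    return delta, beta


# vmap over a leading "materials" axis replaces a Python loop over formulas;
# `scale` is shared across materials here (in_axes=None).
batched = jax.vmap(scattering_and_derived, in_axes=(0, 0, 0, None))
```

`vmap` requires every material in a batch to have the same number of elements. Formulas with fewer elements could be padded with zero counts.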
Which computations benefit most:
| Function | Pattern | JIT Benefit | vmap Benefit |
|---|---|---|---|
| | Matrix multiply (`einsum`) + broadcast | HIGH – fuses multiply + sum + scale into a single kernel | HIGH – vectorize over materials |
| | Element-wise sqrt, division, multiply | MODERATE – eliminates 5 intermediate arrays | LOW – already vectorized |
| | Piecewise polynomial evaluation | HIGH – fuses search + evaluate into a single kernel | HIGH – vectorize over elements |
| | Scalar division | LOW – too simple to benefit | N/A |
| | Regex string parsing | NONE – not numerical | N/A |
| Multi-material batch | Loop over formulas | MODERATE | HIGH – |
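The "search + evaluate" fusion in the table can be illustrated by evaluating a cubic piecewise polynomial inside a jitted function. The coefficients use SciPy's `PPoly` layout, which `PchipInterpolator` exposes as `.x` and `.c`. This is a sketch, not the interpax-based adapter the ADR refers to. It assumes all elements share one breakpoint grid so their coefficient arrays can be stacked for `vmap`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def eval_ppoly(breaks, coeffs, xq):
    """Evaluate a cubic piecewise polynomial stored in SciPy's PPoly layout.

    breaks: (m,) sorted breakpoints
    coeffs: (4, m - 1) coefficients, highest power first (PPoly.c layout)
    xq:     query points; points outside [breaks[0], breaks[-1]] are
            extrapolated from the first/last interval
    """
    # Search: index of the interval containing each query point.
    idx = jnp.searchsorted(breaks, xq, side="right") - 1
    idx = jnp.clip(idx, 0, breaks.shape[0] - 2)
    dx = xq - breaks[idx]
    c = coeffs[:, idx]  # (4, n_queries)
    # Evaluate: Horner's rule on each local cubic.
    return ((c[0] * dx + c[1]) * dx + c[2]) * dx + c[3]


# vmap over elements: one (4, m - 1) coefficient table per element,
# shared breakpoints and query points (in_axes=None).
eval_elements = jax.vmap(eval_ppoly, in_axes=(None, 0, None))
```

For example, stacking `p.c` from several `PchipInterpolator` objects built on the same breakpoints gives an `(n_elements, 4, m - 1)` array. `eval_elements` then evaluates all of them in one call.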
Decision
Adopt JAX as the primary computation backend, with NumPy retained at I/O boundaries.
Specifically:
- All numerical computation in `calculators/`, `optimization/`, and `data_handling/` will use `jax.numpy` via a backend abstraction layer.
- The `backend/` module will provide a `Protocol`-based `ArrayBackend` interface with `NumpyBackend` and `JaxBackend` implementations.
- `NumpyBackend` will be the default during migration (zero behavior change), switchable to `JaxBackend` via `set_backend("jax")` or an environment variable.
- After Phase 2, `JaxBackend` becomes the default, with `NumpyBackend` as the fallback.
- I/O operations (file loading, CSV export, pandas DataFrames) remain NumPy/pandas.
- The `PchipInterpolator` dependency is abstracted behind `InterpolationFactory`, allowing SciPy and interpax implementations to coexist.
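The sketch below shows one possible shape for the `backend/` module. `ArrayBackend`, `NumpyBackend`, `JaxBackend`, and `set_backend` come from the decision above. The three-method interface, `get_backend()`, and the environment variable name `XRAYLABTOOL_BACKEND` are assumptions made for illustration.

```python
import os
from typing import Any, Optional, Protocol

import numpy as np


class ArrayBackend(Protocol):
    """Structural interface every backend must satisfy (method set is illustrative)."""

    name: str

    def asarray(self, x: Any) -> Any: ...
    def einsum(self, subscripts: str, *operands: Any) -> Any: ...
    def sqrt(self, x: Any) -> Any: ...


class NumpyBackend:
    name = "numpy"

    def asarray(self, x):
        return np.asarray(x)

    def einsum(self, subscripts, *operands):
        return np.einsum(subscripts, *operands)

    def sqrt(self, x):
        return np.sqrt(x)


class JaxBackend:
    name = "jax"

    def __init__(self):
        # Deferred import: JAX is only needed when the optional [jax] extra is installed.
        import jax
        import jax.numpy as jnp

        jax.config.update("jax_enable_x64", True)  # keep float64 parity with NumPy
        self._jnp = jnp

    def asarray(self, x):
        return self._jnp.asarray(x)

    def einsum(self, subscripts, *operands):
        return self._jnp.einsum(subscripts, *operands)

    def sqrt(self, x):
        return self._jnp.sqrt(x)


_BACKENDS = {"numpy": NumpyBackend, "jax": JaxBackend}
_ENV_VAR = "XRAYLABTOOL_BACKEND"  # hypothetical name; the ADR does not fix one
_current: Optional[ArrayBackend] = None


def set_backend(name: str) -> ArrayBackend:
    """Select the active backend by name ("numpy" or "jax")."""
    global _current
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend {name!r}; expected one of {sorted(_BACKENDS)}")
    _current = _BACKENDS[name]()
    return _current


def get_backend() -> ArrayBackend:
    """Return the active backend; on first use, read the env var or default to NumPy."""
    if _current is None:
        return set_backend(os.environ.get(_ENV_VAR, "numpy"))
    return _current
```

Because `ArrayBackend` is a `typing.Protocol`, the implementations satisfy it structurally without inheriting from it. `JaxBackend` imports JAX lazily, so a NumPy-only install never touches the optional dependency.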
Consequences
Positive
- Performance: JIT-compiled scattering factor calculation is expected to achieve a 2-10x speedup for energy sweeps of more than 100 points by eliminating Python dispatch and fusing operations.
- Simplification: The entire `optimization/vectorized_core.py` module (700 lines of manual SIMD heuristics, AVX detection, and contiguity checks) becomes unnecessary, because JAX's XLA compiler handles vectorization, memory layout, and operation fusion itself.
- Future capability: `jax.grad` enables automatic differentiation through the full calculation pipeline, unlocking GPU-accelerated NLSQ fitting and Bayesian inference via NumPyro.
- Batch processing: `jax.vmap` replaces the `ThreadPoolExecutor`-based parallel processing in `batch_processing.py` with hardware-vectorized batch computation.
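The sketch below is a simplified illustration of what `jax.grad` provides. It computes the derivative of the dispersion term δ = r_e λ² n f1 / (2π) with respect to density automatically. The function and its scalar inputs are hypothetical stand-ins. A real pipeline would also differentiate through the interpolation and the element sums.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

R_E = 2.8179403262e-15          # classical electron radius, m
N_A = 6.02214076e23             # Avogadro constant, 1/mol
HC_KEV_ANGSTROM = 12.398419843  # h*c in keV * Angstrom


def delta(density, energy_kev, f1_per_formula, molecular_weight):
    """Simplified dispersion term: delta = r_e * lambda^2 / (2*pi) * n * f1.

    density:          g/cm^3
    energy_kev:       photon energy, keV
    f1_per_formula:   sum over elements of count_i * f1_i for one formula unit
    molecular_weight: g/mol
    """
    wavelength_m = HC_KEV_ANGSTROM / energy_kev * 1e-10
    formula_units_per_m3 = density * 1e6 * N_A / molecular_weight  # cm^-3 -> m^-3
    return R_E * wavelength_m**2 / (2 * jnp.pi) * formula_units_per_m3 * f1_per_formula


# d(delta)/d(density) via automatic differentiation (argument 0).
ddelta_ddensity = jax.grad(delta, argnums=0)
```

Here δ is linear in density, so the gradient equals δ/ρ, which makes the result easy to check. The same `jax.grad` call applies unchanged to non-linear pipelines.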
Negative
- Cold start: JIT compilation adds 100 ms to 2 s to the first call of each compiled function, and again for each new input shape. Mitigated by ahead-of-time (AOT) compilation at import time for hot-path functions.
- Debugging complexity: JIT-compiled functions produce opaque XLA error messages. Mitigated by using the `jax.disable_jit()` context manager during development.
- Dependency weight: JAX + jaxlib add ~200 MB to the install footprint. Mitigated by making JAX optional (`pip install xraylabtool[jax]`).
- Learning curve: Team members need to understand JAX's functional purity constraints: no in-place mutation, and no Python control flow that depends on traced values inside jitted functions.
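The sketch below illustrates the two mitigations listed above. It uses JAX's ahead-of-time API (`jax.jit(...).lower(...).compile()`) and `jax.disable_jit()`. `derived_quantities` is a hypothetical stand-in, and the 200-point shape is an arbitrary example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def derived_quantities(f1, f2, scale):
    """Hypothetical stand-in for a hot-path function."""
    delta = scale * f1
    beta = scale * f2
    critical_angle = jnp.sqrt(2.0 * delta)  # radians, small-angle approximation
    return delta, beta, critical_angle


derived_quantities_jit = jax.jit(derived_quantities)

# AOT: trace, lower, and compile once for a fixed shape/dtype (e.g. at import time),
# so the first real call with these shapes skips compilation.
_vec = jax.ShapeDtypeStruct((200,), jnp.float64)
_scalar = jax.ShapeDtypeStruct((), jnp.float64)
derived_quantities_aot = derived_quantities_jit.lower(_vec, _vec, _scalar).compile()


def debug_derived_quantities(f1, f2, scale):
    # Same jitted function, executed op by op: Python tracebacks, print() and pdb work.
    with jax.disable_jit():
        return derived_quantities_jit(f1, f2, scale)
```

The compiled object only accepts the exact shapes and dtypes it was lowered for. Import-time warm-up can therefore only cover energy-grid lengths known in advance. Other lengths go through `derived_quantities_jit` and compile on first use.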
Risks
- interpax maturity: The PCHIP interpolation adapter depends on interpax for JIT-compatible interpolation. If interpax proves inadequate, fall back to `jax.pure_callback` wrapping SciPy (with a JIT-boundary penalty).
- Numerical equivalence: JAX's XLA compiler may reorder floating-point operations, producing results that differ at the ~1e-15 level. The golden test suite, with its 1e-12 tolerance, provides a safety net. This assumes 64-bit mode is enabled (`jax.config.update("jax_enable_x64", True)`). By default JAX computes in float32, whose ~1e-7 relative precision would fail a 1e-12 tolerance.
Appendix: Functions to JIT-Compile (Priority Order)
1. `calculate_scattering_factors()` – hot path, called once per material per calculation
2. `calculate_derived_quantities()` – called immediately after the scattering factors
3. `vectorized_interpolation_batch()` – inner loop of multi-element processing
4. `vectorized_multi_material_batch()` – outer batch loop (candidate for `vmap`)
5. `EnergyConfig.to_array()` – trivial but called frequently
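Jitting something like `EnergyConfig.to_array()` raises one practical constraint. The number of points determines the output shape, so it has to be a static argument. The sketch below assumes a start/stop/count style configuration. The actual `EnergyConfig` fields are not specified in this ADR, and `energy_grid` is a hypothetical name.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@functools.partial(jax.jit, static_argnames=("num_points",))
def energy_grid(start_kev, stop_kev, num_points):
    """Hypothetical jitted body for an energy-grid builder.

    num_points fixes the output shape, so it must be static: each distinct
    value triggers one compilation; later calls with that value hit the cache.
    start_kev and stop_kev stay traced and can change without recompiling.
    """
    return jnp.linspace(start_kev, stop_kev, num_points)
```

This pattern suits workloads where a handful of grid sizes dominate. Each new `num_points` value pays the compile cost once.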
Appendix: Functions that MUST NOT be JIT-compiled
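- `load_scattering_factor_data()` – file I/O, side effects (caching)
- `create_scattering_factor_interpolators()` – object creation, side effects (LRU cache)
- `parse_formula()` – regex string processing
- `_warm_priority_cache()` – threading, side effects
- Any function that raises `ValueError`/`FileNotFoundError` – jitted code cannot raise exceptions based on traced array values (a Python `if` on a traced value fails at trace time), so validation belongs in un-jitted wrappers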
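The sketch below keeps exception-raising validation outside the jitted region. A plain Python wrapper checks its inputs and raises `ValueError`. It then calls a pure kernel that uses `jnp.where` instead of Python branches. Both function names are hypothetical. The kernel computes the attenuation length λ/(4πβ).

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


@jax.jit
def _attenuation_length_kernel(beta, wavelength_m):
    # Pure numerical kernel: no exceptions, no I/O. Branches on traced values
    # use jnp.where; the inner where avoids dividing by zero.
    safe_beta = jnp.where(beta > 0, beta, 1.0)
    length = wavelength_m / (4 * jnp.pi * safe_beta)
    return jnp.where(beta > 0, length, jnp.inf)


def attenuation_length(beta, wavelength_m):
    """Public wrapper: validate eagerly in Python, then call the jitted kernel."""
    beta = np.asarray(beta, dtype=np.float64)
    if np.any(beta < 0):
        raise ValueError("beta must be non-negative")
    return _attenuation_length_kernel(beta, wavelength_m)
```

Keeping the raise in the wrapper preserves ordinary error messages for callers. The kernel stays free to be jitted, vmapped, or differentiated.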