ADR-003: float64 Preservation Policy

Status: ACCEPTED Date: 2026-04-06 Deciders: Architecture Team Supersedes: None


Context

The float64 Requirement in X-ray Calculations

pyXRayLabTool calculates X-ray optical properties that span many orders of magnitude:

Quantity

Typical Range

Precision Needed

Dispersion delta

1e-5 to 1e-8

Relative 1e-6

Absorption beta

1e-6 to 1e-10

Relative 1e-6

Attenuation length (cm)

1e-4 to 1e+2

Relative 1e-4

Critical angle (degrees)

0.01 to 1.0

Absolute 1e-4

SLD (Ang^-2)

1e-5 to 1e-6

Relative 1e-6

Energy (eV)

30 to 30000

Absolute 0.1

Wavelength (m)

4e-11 to 4e-8

Relative 1e-8

The computation involves products of physical constants spanning 30+ orders of magnitude:

# From constants.py
THOMPSON = 2.8179403227e-15    # meters
AVOGADRO = 6.02214199e23       # mol^-1
SCATTERING_FACTOR = THOMPSON * AVOGADRO * 1e6 / (2 * pi)  # ~2.7e14

# From core.py calculate_scattering_factors():
common_factor = SCATTERING_FACTOR * mass_density / molecular_weight
wave_sq = wavelength ** 2    # ~(1e-10)^2 = 1e-20
dispersion = wave_sq * common_factor * f1_total
# = 1e-20 * 2.7e14 * density/MW * f1 -> ~1e-6

The intermediate product wave_sq * common_factor is ~1e-6, and the final dispersion is also ~1e-6. With float32 (7 decimal digits of precision), the relative error in dispersion would be ~1e-1, which is catastrophically wrong. float64 is non-negotiable for correctness.

JAX float32 Default

JAX defaults to float32 for all operations. This is a deliberate design choice for ML workloads where float32 is sufficient and 2x faster. However, JAX provides a configuration flag:

jax.config.update("jax_enable_x64", True)  # Enable float64 globally

When enabled, jnp.array([1.0]) creates a float64 array, matching NumPy behavior.

Current Codebase float64 Usage

The codebase is consistently float64:

# core.py -- explicit float64 everywhere
np.zeros(n_energies, dtype=np.float64, order="C")
np.empty((n_elements, n_energies), dtype=np.float64, order="C")
np.asarray(f1_values, dtype=np.float64)

# vectorized_core.py
np.empty((n_elements, n_energies), dtype=np.float64, order="C")

# batch_processing.py
np.array([float(energies)], dtype=np.float64)

All array types in typing_extensions.py are defined as NDArray[np.float64].

Decision

Enforce float64 globally via JAX configuration at application startup. Treat float32 computation as a correctness bug.

Implementation:

  1. Set jax.config.update("jax_enable_x64", True) in xraylabtool/__init__.py before any JAX imports.

  2. Add a runtime assertion in backend/jax_impl.py that verifies x64 mode is enabled.

  3. All explicit dtype=np.float64 specifications in the codebase are preserved (they become dtype=jnp.float64).

  4. The validate_energy_array() function in typing_extensions.py already checks dtype in [np.float32, np.float64] – tighten this to float64-only.

  5. Add a CI check that runs the test suite with JAX_ENABLE_X64=0 and expects failures (proving that float64 is correctly enforced).

Consequences

Positive

  • Correctness guarantee: No possibility of silent float32 precision loss in X-ray calculations.

  • Behavioral parity: jnp.array([1.0]).dtype == jnp.float64, identical to NumPy.

  • Explicit contract: The dtype policy is documented and enforced, not implicit.

Negative

  • Performance cost: float64 operations are ~2x slower than float32 on GPU and ~1.3x slower on CPU. For this workload (X-ray property calculations, not ML training), this is an acceptable tradeoff – correctness is Priority 1 per CLAUDE.md.

  • Memory cost: float64 arrays use 2x the memory of float32. For typical energy sweeps (50-1000 points, 1-10 materials), this means kilobytes, not a concern.

  • JAX ecosystem friction: Some JAX libraries default to float32. Any third-party JAX code must be validated for float64 compatibility.

Enforcement Strategy

# xraylabtool/__init__.py (FIRST LINES, before any other imports)
import os
os.environ.setdefault("JAX_ENABLE_X64", "1")

# Verify enforcement at import time
def _verify_float64():
    """Verify that JAX float64 mode is enabled."""
    try:
        import jax
        import jax.numpy as jnp
        if not jax.config.jax_enable_x64:
            raise RuntimeError(
                "JAX float64 mode is required for X-ray calculations. "
                "Set JAX_ENABLE_X64=1 or call "
                "jax.config.update('jax_enable_x64', True)"
            )
        # Verify actual behavior
        test_val = jnp.array(1.0)
        if test_val.dtype != jnp.float64:
            raise RuntimeError("JAX float64 mode enabled but arrays are float32")
    except ImportError:
        pass  # JAX not installed, numpy backend handles this

_verify_float64()

Appendix: Numerical Precision Requirements by Function

Function

Min Precision Required

float32 OK?

float64 OK?

energy_to_wavelength

1e-8 relative

NO (6e-7 range)

YES

calculate_scattering_factors

1e-12 relative

NO (1e-7 intermediate products)

YES

calculate_derived_quantities

1e-8 relative

NO (sqrt of 1e-8 values)

YES

calculate_critical_angle

1e-4 absolute

MARGINAL

YES

calculate_attenuation_length

1e-4 relative

NO (division by 1e-10)

YES

parse_formula

Exact (integer)

N/A

N/A

PchipInterpolator

1e-6 relative

NO (cubic polynomial coefficients)

YES

Conclusion: No function in the calculation pipeline can safely use float32.