JAX Architecture & Design¶
Audience: Developers and users who want a deep understanding of the v0.4.0 design.
Purpose: Explain why JAX, how it works, and how XRayLabTool uses it.
Why JAX?¶
What is JAX?¶
JAX is a Python library for numerical computing with:
1. JIT (Just-In-Time) Compilation - Convert Python → compiled machine code
2. Automatic Differentiation - Compute gradients with grad()
3. Vectorization - Apply functions across batches with vmap()
4. GPU/TPU Support - Automatic device acceleration
XRayLabTool uses features 1 and 3; features 2 and 4 enable future scientific workflows.
NumPy vs JAX¶
import numpy as np
import jax
import jax.numpy as jnp
# NumPy: eager evaluation
x_np = np.array([1.0, 2.0, 3.0])
y_np = np.sin(x_np) # Computed immediately
print(y_np) # [0.841, 0.909, 0.141]
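# JAX: eager by default, dispatched asynchronously, optionally JIT-compiled
x_jax = jnp.array([1.0, 2.0, 3.0])
y_jax = jnp.sin(x_jax) # Dispatched immediately; the call may return before the result is ready
# Explicit compilation
sin_jit = jax.jit(jnp.sin)
y_compiled = sin_jit(x_jax) # Traced and compiled on the first call, then cached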
Performance Benefits¶
1. JIT Compilation: 10-100x speedup
NumPy (eager): ■■■■■■■■■■■■■■■■■■■■ (1.0x baseline)
JAX (eager): ■■■■■■■■■■■■■■■■■■■■ (0.95x, ~same)
JAX (JIT compiled): ■ (0.01-0.1x, 10-100x faster)
2. Vectorization: Automatic batching with vmap()
# NumPy: manual loop
energies = np.linspace(1000, 20000, 1000)
results = []
for e in energies:
    result = calculate_property(e) # 1000 calls
    results.append(result)
# JAX: automatic vectorization
calculate_batch = jax.vmap(calculate_property)
results = calculate_batch(energies) # Single compiled call
3. GPU Acceleration: Automatic offloading
CPU: ■■■■■■■■■■■■■■■■■■■■ (1.0x baseline)
GPU (V100): ■ (0.02x, 50x faster)
GPU (A100): ■ (0.01x, 100x faster)
Design Philosophy¶
JAX adoption strategy for XRayLabTool:
Transparent to users - API stays the same
Backward compatible - v0.3.0 code works unchanged
Gradual adoption - Don’t rewrite everything at once
Future-proof - Enable autodiff and hardware acceleration later
JAX Fundamentals¶
JIT Compilation¶
What is JIT?
JIT converts Python functions into compiled machine code:
import jax
import jax.numpy as jnp
# Normal function (eager evaluation)
def compute(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
# First call: Python interpreter runs
result1 = compute(1.0) # ~0.1 ms (with overhead)
# Compiled version
compute_jit = jax.jit(compute)
# First call: JAX compiles function (expensive)
result2 = compute_jit(1.0) # ~50 ms (includes compilation)
# Subsequent calls: use cached compiled code (fast)
result3 = compute_jit(1.0) # ~0.001 ms
Cost-Benefit:
Compilation cost: 50-100 ms (one-time)
Per-call speedup: 100-1000x
Break-even: roughly 500-1000 calls at the illustrative timings above (50-100 ms compile cost / ~0.1 ms saved per call)
Why this matters for XRayLabTool:
Calculation functions run thousands of times (batch processing)
One-time compilation cost is negligible
Per-call speedup dominates overall performance
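The compile-once, run-many pattern can be observed directly. The following is a minimal sketch rather than XRayLabTool code: it reuses the toy `compute` function from above, the array size and call count are arbitrary choices, and absolute timings depend on hardware and JAX version.

```python
import time

import jax
import jax.numpy as jnp


def compute(x):
    # Toy function from the example above; mathematically always 1.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


compute_jit = jax.jit(compute)
x = jnp.linspace(0.0, 1.0, 1000)

# First call: tracing + XLA compilation + execution.
t0 = time.perf_counter()
first = compute_jit(x).block_until_ready()
first_call_s = time.perf_counter() - t0

# Later calls with the same shape and dtype reuse the cached executable.
n_calls = 100
t0 = time.perf_counter()
for _ in range(n_calls):
    cached = compute_jit(x).block_until_ready()
cached_call_s = (time.perf_counter() - t0) / n_calls

print(f"first call:  {first_call_s * 1e3:.3f} ms")
print(f"cached call: {cached_call_s * 1e3:.3f} ms")
```

Dividing the first-call time by the per-call saving gives the break-even call count for a particular machine.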
Traces and Shapes¶
JAX compiles based on input shapes and dtypes, not values:
import jax
import jax.numpy as jnp
@jax.jit
def process(x):
    return jnp.sum(x) * 2
# First call with shape (3,)
result1 = process(jnp.array([1.0, 2.0, 3.0]))
# JAX traces and compiles for shape (3,)
# Second call with shape (3,) - reuses compilation
result2 = process(jnp.array([4.0, 5.0, 6.0]))
# Cache hit, fast
# Third call with shape (5,) - recompiles!
result3 = process(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]))
# New shape = recompilation (slow)
Implication for XRayLabTool:
Single-material calculations always use the same shape
Batch processing with consistent batch sizes benefits from caching
Different input sizes trigger recompilation
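One way to see when recompilation happens is to count traces. In this sketch the counter is a Python side effect, which only runs while JAX traces the function; `trace_count` is an illustrative name, not part of JAX or XRayLabTool.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def process(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cache hits
    return jnp.sum(x) * 2


process(jnp.array([1.0, 2.0, 3.0]))            # shape (3,): trace + compile
process(jnp.array([4.0, 5.0, 6.0]))            # shape (3,): cache hit
process(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]))  # shape (5,): trace + compile again

print(trace_count)  # 2
```

Three calls produce two traces: one per distinct input shape.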
Automatic Differentiation¶
Future capability (enabled by JAX, not yet used):
import jax
import jax.numpy as jnp
def scattering_factor(energy):
    """Calculate f1 component."""
    return jnp.exp(-energy / 1000) # Simplified
# Compute gradient df1/dE
gradient_fn = jax.grad(scattering_factor)
dfdE = gradient_fn(8000.0) # Automatic differentiation
Future use cases:
Optimization (find optimal angles, densities)
Uncertainty quantification
Inverse problems (given results, find parameters)
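A quick sanity check for any autodiff result is to compare it with a hand-derived derivative. The sketch below does this for the simplified stand-in function above; `analytic_derivative` is a helper introduced here for illustration, and the exponential is not a physical model of f1.

```python
import jax
import jax.numpy as jnp


def scattering_factor(energy):
    # Same simplified stand-in as above, not a physical model.
    return jnp.exp(-energy / 1000.0)


def analytic_derivative(energy):
    # d/dE exp(-E/1000) = -(1/1000) * exp(-E/1000)
    return -jnp.exp(-energy / 1000.0) / 1000.0


energy = 8000.0  # jax.grad needs a floating-point input
autodiff_value = jax.grad(scattering_factor)(energy)
reference_value = analytic_derivative(energy)

print(autodiff_value, reference_value)
```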
Vectorization with vmap¶
Automatic batching - future optimization:
import jax
import jax.numpy as jnp
def calculate_one(energy):
    """Single energy calculation."""
    return energy ** 2
# Vectorized version
calculate_batch = jax.vmap(calculate_one)
energies = jnp.array([1000, 5000, 8000, 10000])
results = calculate_batch(energies) # Runs as one batched operation (wrap in jax.jit to compile it)
Current approach (v0.4.0):
XRayLabTool manually implements batching, which JAX compiles with JIT.
Future optimization (v0.5+):
Use vmap() for automatic vectorization without manual loops.
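As a rough illustration of what that change would look like, the sketch below checks that a vmap-based batch function matches a manual loop. It uses the toy `calculate_one` from above, not an XRayLabTool function.

```python
import jax
import jax.numpy as jnp


def calculate_one(energy):
    # Toy per-energy calculation from the example above.
    return energy ** 2


energies = jnp.array([1000.0, 5000.0, 8000.0, 10000.0])

# Manual Python loop: one dispatch per element.
looped = jnp.stack([calculate_one(e) for e in energies])

# vmap adds the batch dimension; jit compiles the batched function.
calculate_batch = jax.jit(jax.vmap(calculate_one))
batched = calculate_batch(energies)

print(batched.shape)  # (4,)
```

Note that `vmap` only vectorizes; compilation still comes from `jax.jit`, so the two are usually combined.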
XRayLabTool Implementation¶
Architecture Overview¶
User Code
│
├─→ API Layer (xraylabtool.__init__)
│ └─→ validate inputs
│ convert to JAX arrays
│
├─→ Calculation Layer (calculators/)
│ └─→ @jax.jit compiled functions
│ JAX array operations
│ physics calculations
│
├─→ Data Layer (data_handling/)
│ └─→ atomic data cache
│ JAX array storage
│ efficient lookups
│
└─→ Output
└─→ XRayResult dataclass
JAX arrays for properties
Compilation in XRayLabTool¶
Key functions decorated with @jax.jit:
# In xraylabtool/calculators/core.py
@jax.jit
def calculate_refraction_indices(energy, delta, beta):
    """Core JIT-compiled calculation."""
    n = 1 - delta - 1j * beta
    return n
@jax.jit
def calculate_critical_angle(delta):
    """JIT-compiled critical angle."""
    return jnp.sqrt(2 * delta)
Warm-up pattern:
# First call triggers compilation
result = calculate_single_material_properties("Si", 2.33, 8000)
# ~50 ms (includes JIT compilation for shape/type)
# Subsequent calls use cached compilation
result = calculate_single_material_properties("Si", 2.33, 8000)
# ~0.02 ms (no recompilation)
Data Flow¶
v0.4.0 data types:
# User input: Python/NumPy
formula = "SiO2" # str
density = 2.2 # float
energy = 8000 # float
# Internal: JAX arrays
|-> Parse formula -> JAX array of atomic numbers
|-> Load atomic data -> JAX arrays
|-> JIT compute -> JAX array results
\-> Package -> XRayResult with JAX arrays
# Output: JAX arrays (NumPy compatible)
result.critical_angle_degrees # jax.Array (shape: (1,))
result.attenuation_length_cm # jax.Array (shape: (1,))
Memory Management¶
JAX vs NumPy memory semantics:
import jax.numpy as jnp
# JAX arrays are immutable
x = jnp.array([1.0, 2.0, 3.0])
y = x + 1 # Creates new array, doesn't modify x
# Asynchronous dispatch
z = jnp.sin(x) # Dispatched immediately; the call may return before the result is ready
# Synchronization
z_ready = z.block_until_ready() # Blocks until the computation has finished
For XRayLabTool users:
No accidental in-place modification (arrays are immutable)
Fewer temporary arrays inside jitted functions (XLA can fuse operations)
Transparent memory management (automatic)
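The immutability rule is easiest to see by trying to break it. This minimal sketch shows the rejected in-place write and the functional update that replaces it; `mutated_in_place` is just a flag introduced for the demonstration.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 10.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(10.0)

# Work is dispatched asynchronously; block_until_ready() waits for the result.
y = y.block_until_ready()

print(x, y)
```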
GPU Acceleration¶
Automatic device placement:
import jax
# JAX automatically detects and uses available devices
print(jax.devices()) # Shows available hardware
# No code changes needed
result = calculate_single_material_properties("Si", 2.33, 8000)
# Automatically runs on GPU if available
Device hints (advanced):
import jax
# Force CPU
with jax.default_device(jax.devices("cpu")[0]):
    result = calculate(...)
# Force GPU
with jax.default_device(jax.devices("gpu")[0]):
    result = calculate(...)
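To check where data actually lives, an array can be placed explicitly and then inspected. The sketch below assumes only that a CPU backend is present, which is normally the case; it does not require a GPU.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]  # assumes the CPU backend is enabled

# Explicitly place an array on the CPU device.
x = jax.device_put(jnp.arange(4.0), cpu)

# Computations follow explicitly placed inputs, so y stays on the same device.
y = jnp.sqrt(x)

print(jax.default_backend())  # "cpu", "gpu", or "tpu" depending on the install
print(y.devices())
```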
Type System¶
JAX array types:
import jax.numpy as jnp
import numpy as np
x_jax = jnp.array([1.0, 2.0])
x_np = np.array([1.0, 2.0])
print(type(x_jax)) # e.g. <class 'jaxlib.xla_extension.ArrayImpl'> (exact path varies by jaxlib version)
print(type(x_np)) # <class 'numpy.ndarray'>
# NumPy operations work with JAX arrays
np.sin(x_jax) # Works!
np.concatenate([x_jax, x_np]) # Works!
# Type hints
from typing import Union
from jax import Array
def process(x: Union[np.ndarray, Array]) -> float:
    return float(jnp.mean(x))
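Because the concrete array class is an implementation detail, type checks are more robust against the public `jax.Array` type. The following sketch also shows that NumPy functions hand back NumPy arrays; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

x_jax = jnp.array([1.0, 2.0])

# jax.Array is the public type to test against, whatever the concrete class is.
is_jax_array = isinstance(x_jax, jax.Array)

# NumPy functions convert JAX arrays to NumPy arrays on the host.
as_numpy = np.asarray(x_jax)
mixed = np.concatenate([x_jax, np.array([3.0, 4.0])])

print(type(as_numpy).__name__, type(mixed).__name__)
```

Once a value has passed through a NumPy function it is an ordinary `numpy.ndarray`, so it is no longer traced by `jax.jit` or `jax.grad`.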
Configuration & Control¶
JAX Configuration Options¶
For advanced users:
import jax
# Disable JIT (useful for debugging)
jax.config.update("jax_disable_jit", True)
# Disable GPU (force CPU)
jax.config.update("jax_platforms", "cpu")
# Precision control (JAX defaults to 32-bit floats)
jax.config.update("jax_enable_x64", True)
When running scripts, the same switches can be set from the shell:
# Environment variables
JAX_DISABLE_JIT=1 python script.py # Debug without JIT
JAX_PLATFORMS=cpu python script.py # Force CPU
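For debugging a single code path, JIT can also be switched off locally with the `jax.disable_jit()` context manager instead of a global setting. This sketch uses a Python-side counter (an illustrative name, not a JAX feature) to show when the function body really executes.

```python
import jax
import jax.numpy as jnp

python_body_runs = 0


@jax.jit
def square(x):
    global python_body_runs
    python_body_runs += 1  # runs only when the Python body executes
    return x ** 2


x = jnp.array([1.0, 2.0])

square(x)  # traced and compiled: body runs once
square(x)  # cache hit: body does not run
runs_with_jit = python_body_runs

# Inside the context manager the function runs eagerly on every call.
with jax.disable_jit():
    square(x)
    square(x)
runs_without_jit = python_body_runs - runs_with_jit

print(runs_with_jit, runs_without_jit)
```

With JIT disabled, ordinary `print` calls and debuggers see concrete values on every call.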
Common Gotchas & Solutions¶
Problem: Shape-based recompilation¶
Issue:
@jax.jit
def calculate(energies):
    return jnp.sin(energies)
# Recompilation for every new input length
for energy_list in many_lists:
    result = calculate(energy_list) # Different sizes = recompile!
Solution:
# Pad to fixed size
max_size = 1000
def pad(x):
    return jnp.pad(x, (0, max_size - len(x)))
energy_fixed = pad(energy_list)
result = calculate(energy_fixed)[:len(energy_list)] # Single compilation; drop the padded entries
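Padding changes the data, so reductions over a padded array need a mask. The sketch below is one possible way to handle this, with a hypothetical `MAX_SIZE` and a cosine in place of the real calculation so that unmasked zeros would visibly corrupt the sum.

```python
import jax
import jax.numpy as jnp

MAX_SIZE = 8  # hypothetical upper bound on the batch length
trace_count = 0


@jax.jit
def masked_sum_of_cosines(padded, n_valid):
    global trace_count
    trace_count += 1  # counts traces, i.e. compilations
    mask = jnp.arange(MAX_SIZE) < n_valid  # True for real entries only
    return jnp.sum(jnp.where(mask, jnp.cos(padded), 0.0))


def pad(x):
    # Zero-pad on the right up to the fixed size.
    return jnp.pad(x, (0, MAX_SIZE - x.shape[0]))


batches = [jnp.array([0.5, 1.0]), jnp.array([0.1, 0.2, 0.3]), jnp.ones(5)]
results = [masked_sum_of_cosines(pad(b), b.shape[0]) for b in batches]
expected = [jnp.sum(jnp.cos(b)) for b in batches]

print(trace_count)  # 1
```

The valid length is passed as a traced scalar, so batches of different lengths share one compiled function.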
Problem: Python control flow¶
Issue:
@jax.jit
def conditional_calc(x, use_gpu):
    if use_gpu: # Fails under jit: use_gpu is a traced value, not a Python bool
        return gpu_calc(x)
    else:
        return cpu_calc(x)
Solution:
@jax.jit
def conditional_calc(x, use_gpu):
    # jnp.where evaluates both branches and selects the result
    return jnp.where(use_gpu, gpu_calc(x), cpu_calc(x))
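Two other common ways to branch inside a jitted function are `jax.lax.cond` and a static argument. The sketch below shows both under stated assumptions: `fast_path` and `slow_path` are hypothetical stand-ins for the two branches, not XRayLabTool functions.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fast_path(x):  # hypothetical stand-in for one branch
    return x * 2.0


def slow_path(x):  # hypothetical stand-in for the other branch
    return x + 100.0


@jax.jit
def with_cond(x, flag):
    # flag is a traced boolean; lax.cond selects a branch at run time.
    return jax.lax.cond(flag, fast_path, slow_path, x)


@partial(jax.jit, static_argnames="flag")
def with_static_flag(x, flag):
    # flag is a compile-time constant, so a plain Python conditional is allowed.
    # Each distinct flag value triggers its own compilation.
    return fast_path(x) if flag else slow_path(x)


x = jnp.array([1.0, 2.0])
print(with_cond(x, True), with_static_flag(x, False))
```

`lax.cond` suits flags that change from call to call; a static argument suits flags that rarely change, since every new value costs one compilation.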
Problem: Non-array inputs¶
Issue:
@jax.jit
def calc(material_name, energy):
    data = MATERIAL_DATA[material_name] # Never reached: jit rejects the str argument
    return compute(data, energy)
Solution:
# Move non-array inputs outside @jax.jit
def calc(material_name, energy):
    data = MATERIAL_DATA[material_name] # Python (outside JIT)
    return compute_jit(data, energy) # JAX (inside JIT)
@jax.jit
def compute_jit(data, energy):
    return jnp.sin(energy) * data
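An alternative, when the non-array input is hashable, is to mark it as static so the lookup happens at trace time. This is a sketch only: the `MATERIAL_DATA` table below is a hypothetical two-entry stand-in built from the densities used elsewhere in this document.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical lookup table standing in for MATERIAL_DATA.
MATERIAL_DATA = {"Si": 2.33, "SiO2": 2.2}


@partial(jax.jit, static_argnames="material_name")
def calc(material_name, energy):
    # material_name is static, so this lookup runs in Python at trace time
    # and the value is baked into the compiled function as a constant.
    data = MATERIAL_DATA[material_name]
    return jnp.sin(energy) * data


energy = jnp.array([1.0, 2.0])
si_result = calc("Si", energy)
sio2_result = calc("SiO2", energy)  # new static value: separate compilation

print(si_result, sio2_result)
```

The trade-off is one compilation per distinct material name, which is acceptable for a handful of materials but not for thousands.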
Testing with JAX¶
Unit Testing Considerations¶
import jax
import jax.numpy as jnp
import pytest
def test_calculation_matches_reference():
    """JAX results match expected output."""
    result = calculate_single_material_properties("Si", 2.33, 8000)
    angle = float(result.critical_angle_degrees[0])
    assert abs(angle - 0.158) < 0.001 # Allow small tolerance
def test_jit_consistency():
    """JIT and eager evaluation match."""
    # Compile once
    compute_jit = jax.jit(core_compute)
    x = jnp.array([1000, 5000, 8000])
    result_eager = core_compute(x) # Eager (no JIT)
    result_jit = compute_jit(x) # Compiled
    assert jnp.allclose(result_eager, result_jit)
Debugging Strategies¶
import jax
# Disable JIT for debugging
jax.config.update("jax_disable_jit", True)
# Add prints (with caveats)
@jax.jit
def debug_calc(x):
    print(x.shape) # Prints the shape at trace time only, not on cached calls
    return x ** 2
# Use jax.debug.print for runtime prints
from jax import debug
@jax.jit
def debug_calc2(x):
    debug.print("Value: {}", x)
    return x ** 2
Performance Profiling¶
import time
import jax
import jax.numpy as jnp
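# Warm up JIT
compute_fn = jax.jit(your_function)
compute_fn(example_input).block_until_ready()
# Profile (block_until_ready so asynchronous dispatch does not hide the run time)
start = time.perf_counter()
for _ in range(1000):
    result = compute_fn(example_input).block_until_ready()
elapsed = (time.perf_counter() - start) / 1000
print(f"Time per call: {elapsed * 1000:.3f} ms")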
Future Directions¶
Potential JAX Features for v0.5+¶
1. Autodiff for Optimization
from jax import grad
def objective(energy):
    """Squared deviation of the critical angle from a target value."""
    result = calculate_single_material_properties("Si", 2.33, energy)
    return (result.critical_angle_degrees[0] - target) ** 2
gradient = grad(objective)
optimal_energy = optimize(objective, gradient, initial_guess)
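To make the idea concrete without relying on XRayLabTool internals, here is a self-contained toy version. It inverts the critical-angle relation θc = sqrt(2δ) used earlier by plain gradient descent; the target value, starting point, learning rate, and unit scaling are all assumptions chosen for this illustration.

```python
import jax
import jax.numpy as jnp

TARGET_MRAD = 4.0  # hypothetical target critical angle in milliradians


def critical_angle_mrad(delta_ppm):
    # theta_c = sqrt(2 * delta); with delta in units of 1e-6 the result is in mrad.
    return jnp.sqrt(2.0 * delta_ppm)


def objective(delta_ppm):
    return (critical_angle_mrad(delta_ppm) - TARGET_MRAD) ** 2


grad_objective = jax.jit(jax.grad(objective))

delta_ppm = 5.0      # initial guess
learning_rate = 4.0  # hand-picked for this toy problem
for _ in range(50):
    delta_ppm = delta_ppm - learning_rate * grad_objective(delta_ppm)

print(float(delta_ppm))
```

The analytic answer is δ = 8 × 10⁻⁶ (since sqrt(2 · 8) = 4), which gives a reference to compare the optimizer against. A real workflow would more likely hand the gradient to an established optimizer than use a fixed-step loop.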
2. vmap for Automatic Batching
from jax import vmap
# Auto-vectorize over materials (vmap needs array inputs, so formulas would need a numeric encoding)
calc_batch = vmap(calculate_single_material_properties, in_axes=(0, 0, None))
results = calc_batch(formulas, densities, energy)
3. Distributed Processing
import jax
# Shard across multiple GPUs
results = jax.pmap(calculate_single_material_properties, in_axes=(0, 0, None))(
    sharded_materials, densities, energy
)
Further Reading¶
Conclusion¶
JAX modernizes XRayLabTool with:
JIT Compilation: 10-100x speedup with no code changes
GPU Acceleration: Automatic hardware acceleration
Backward Compatibility: Existing code works unchanged
Future Capability: Autodiff and advanced optimization ready
The migration is transparent to users while enabling significant performance improvements and future scientific computing features.
—
Next: See Migration Guide for upgrade instructions.