Migration Guide: v0.3.0 → v0.4.0¶
Overview¶
Version 0.4.0 modernizes XRayLabTool with JAX-based computation and PyQtGraph-based visualization. This guide helps you upgrade from v0.3.0.
Key Facts:
No breaking changes: Existing code works unchanged
Better performance: 5-100x faster due to JAX JIT compilation
Optional GPU support: Automatic GPU acceleration if available
Same API: Function names, signatures, and return types unchanged
Backward compatible: Results are identical to v0.3.0
What Changed¶
Computation Stack¶
Aspect |
v0.3.0 |
v0.4.0 |
|---|---|---|
Core |
NumPy |
JAX |
Compiler |
Eager |
JIT |
Arrays |
nd… |
jax… |
GPU |
No |
Yes |
What stayed the same:
Function API (
calculate_single_material_properties(), etc.)Return types (dataclasses with same fields)
CLI commands (all 9 commands work identically)
Physics calculations (bit-for-bit identical results)
Configuration (no new required settings)
GUI & Visualization¶
v0.3.0 → v0.4.0 changes:
Matplotlib (static plots) → PyQtGraph (interactive plots)
Same single/multi-material analysis workflows
Improved responsiveness and interactivity
Better real-time updates as you change parameters
Installation & Setup¶
Prerequisites¶
Python 3.12+
pip or uv package manager
~500 MB disk space (JAX libraries included)
Upgrade Steps¶
Step 1: Backup your environment (optional)
# Document your current Python environment
pip freeze > requirements_v0_3_0.txt
Step 2: Upgrade XRayLabTool
pip install --upgrade xraylabtool>=0.4.0
This automatically installs:
jax>=0.4.0- Numerical computing with JITjaxlib>=0.4.0- JAX runtimepyqtgraph>=0.13.0- Interactive visualizationAll other dependencies
Step 3: Verify installation
# Check version
python -c "import xraylabtool; print(xraylabtool.__version__)"
# Test basic functionality
python -c "import xraylabtool as xrt; \
result = xrt.calculate_single_material_properties('Si', 2.33, 8000); \
print(f'Critical angle: {result.critical_angle_degrees:.3f}°')"
# Check JAX installation
python -c "import jax; print(f'JAX version: {jax.__version__}'); \
print(f'Devices: {jax.devices()}')"
Step 4: Test GUI (optional)
python -m xraylabtool.gui
Expected output: Modern GUI window opens with interactive plot controls.
Step 5: Update your environment-specific completion (optional)
If you had shell completion installed:
xraylabtool completion install
This updates completion scripts to work with v0.4.0.
What’s Different in Your Code¶
The Good News¶
Your existing code works unchanged:
import xraylabtool as xrt
# This code from v0.3.0 works identically in v0.4.0
result = xrt.calculate_single_material_properties(
formula="Si",
density=2.33,
energy=8000
)
print(f"Critical angle: {result.critical_angle_degrees[0]:.3f}°")
print(f"Attenuation: {result.attenuation_length_cm[0]:.2f} cm")
No changes needed. Results are identical.
JAX Arrays vs NumPy Arrays¶
Small difference: return type
import xraylabtool as xrt
import numpy as np
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# v0.3.0: result.critical_angle_degrees was numpy.ndarray
# v0.4.0: result.critical_angle_degrees is jax.Array
# Both work identically with NumPy functions:
print(type(result.critical_angle_degrees)) # <class 'jaxlib.xla_extension.ArrayImpl'>
# But you can convert if needed:
angle_np = np.asarray(result.critical_angle_degrees) # Now NumPy array
# Or ensure materialized (rarely needed):
angle_ready = result.critical_angle_degrees.block_until_ready()
In practice: JAX arrays are transparent. They work with NumPy functions, plotting, and file I/O. No changes usually needed.
Performance Change: JIT Warm-Up¶
New behavior: JIT compilation on first use
import xraylabtool as xrt
import time
# First calculation includes JIT compilation (~50-100 ms)
start = time.time()
result1 = xrt.calculate_single_material_properties("Si", 2.33, 8000)
print(f"First call: {(time.time() - start) * 1000:.1f} ms") # ~50 ms
# Subsequent calls use compiled code (~0.02 ms)
start = time.time()
result2 = xrt.calculate_single_material_properties("Si", 2.33, 8000)
print(f"Second call: {(time.time() - start) * 1000:.3f} ms") # ~0.02 ms
# Batch processing is very fast
start = time.time()
for energy in range(5000, 15000, 100):
result = xrt.calculate_single_material_properties("Si", 2.33, energy)
elapsed = (time.time() - start) * 1000
print(f"100 calculations: {elapsed:.1f} ms ({elapsed/100:.3f} ms each)")
What to expect:
First function call: slow (includes JIT compilation)
Subsequent calls: very fast (cached compiled code)
This is normal and expected behavior
Think of it as a one-time startup cost
Optimization tip:
# "Warm up" the JIT compiler once at startup:
xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Now all subsequent calculations are fast
for material in my_materials:
result = xrt.calculate_single_material_properties(
material['formula'],
material['density'],
material['energy']
)
Type Hints and Type Checkers¶
For type-aware code:
from typing import Union
import xraylabtool as xrt
import numpy as np
from jax import Array
# v0.4.0: results contain jax.Array
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Type-aware assignment
angle: Union[np.ndarray, Array] = result.critical_angle_degrees
# Or convert for NumPy compatibility
angle_np: np.ndarray = np.asarray(angle)
In practice: Most code doesn’t need changes. Type checkers may produce warnings, but code runs fine.
Performance Gains¶
Benchmark Comparison¶
Single Material Calculation
v0.3.0 (NumPy): 0.15 ms ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
v0.4.0 (JAX): 0.02 ms ■■■■ (7.5x faster)
Batch Processing (1000 materials)
v0.3.0 (NumPy): 150 ms ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
v0.4.0 (JAX): 15 ms ■■ (10x faster)
GPU Acceleration (optional)
v0.4.0 (CPU): 15 ms ■■
v0.4.0 (GPU): 1 ms ■ (15x faster on GPU)
No Code Changes Needed¶
Performance improvements are automatic. No tuning required.
# Same code, 7-100x faster
import xraylabtool as xrt
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
GPU Acceleration (Optional)¶
By default, v0.4.0 runs on CPU. GPU acceleration is optional and automatic if available.
Check Available Devices¶
import jax
print(jax.devices())
Possible outputs:
[cuda(id=0)] # NVIDIA GPU available
[gpu(id=0)] # AMD GPU available
[cpu(id=0)] # CPU only
[tpu(id=0, host_id=0)] # Google TPU available
Install GPU Support (NVIDIA)¶
For NVIDIA CUDA 12:
pip install jax[cuda12_cudnn]
For NVIDIA CUDA 11:
pip install jax[cuda11_cudnn]
For other GPUs or TPUs, see JAX GPU installation guide.
JAX will automatically detect and use available GPUs:
import jax
import xraylabtool as xrt
# Check GPU status
print(f"Using devices: {jax.devices()}")
# Calculations automatically use GPU
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# No code changes needed!
Verify GPU Usage¶
import jax
print("Available devices:", jax.devices())
print("Default device:", jax.default_device())
# GPU is in use if you see cuda(id=X) in the output
Troubleshooting¶
Issue: ImportError about ‘jax’ module¶
Problem:
ModuleNotFoundError: No module named 'jax'
Solution:
pip install --upgrade xraylabtool>=0.4.0
This installs JAX automatically.
Issue: First calculation is slow¶
Problem: First call to calculation function takes 50-100 ms.
Root cause: JAX JIT compilation on first use (expected behavior).
Solution: This is normal. Subsequent calls are fast.
# Warm up JIT compiler once at startup
import xraylabtool as xrt
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Now all subsequent calls are very fast (~0.02 ms)
Issue: GUI doesn’t show plots¶
Problem: GUI window opens but plots are blank.
Solution: Platform-specific PyQtGraph setup may be needed.
# Test PyQtGraph installation
python -c "import pyqtgraph; print('PyQtGraph OK')"
# If that fails, reinstall:
pip install --upgrade --force-reinstall pyqtgraph
# Then try GUI again
python -m xraylabtool.gui
Issue: GPU not being used despite installation¶
Problem:
jax.devices() # Returns [cpu(id=0)] despite GPU installation
Solutions:
Verify CUDA installation:
nvidia-smi # Check NVIDIA driver
nvcc --version # Check CUDA toolkit
Verify JAX can see CUDA:
python -c "import jax; print(jax.config.jax_platforms)"
Reinstall JAX GPU support:
pip uninstall jaxlib jax
pip install jax[cuda12_cudnn]
Check environment variables:
export CUDA_VISIBLE_DEVICES=0 # Force CUDA device 0
python -c "import jax; print(jax.devices())"
Issue: Results different from v0.3.0¶
Problem: Calculations produce slightly different numbers.
Root cause: JAX may use different precision or compiler optimizations.
Solution: Results should be identical. Report if differences are significant:
import xraylabtool as xrt
import numpy as np
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Check if results are close to expected values
# Report to GitHub if significantly different
Rollback to v0.3.0¶
If you encounter critical issues and need to revert:
Step 1: Uninstall v0.4.0
pip uninstall xraylabtool jax jaxlib
Step 2: Install v0.3.0
pip install xraylabtool==0.3.0
Step 3: Verify rollback
python -c "import xraylabtool; print(xraylabtool.__version__)"
Data compatibility: All data, CSV files, and results are compatible between v0.3.0 and v0.4.0. No data migration needed.
Advanced Topics¶
Type Hints for JAX Arrays¶
For code with strict type checking:
from typing import Union
import numpy as np
from jax import Array
import xraylabtool as xrt
def process_results(
result: Union[np.ndarray, Array]
) -> float:
"""Process calculation results."""
return float(result[0])
# Use with v0.4.0
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
angle = process_results(result.critical_angle_degrees)
Mixing NumPy and JAX¶
For hybrid NumPy/JAX code:
import numpy as np
import jax.numpy as jnp
import xraylabtool as xrt
# Get JAX results
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Convert to NumPy if needed for NumPy-only code
angle_np = np.asarray(result.critical_angle_degrees)
# Or use JAX NumPy for JAX-compatible code
angle_jax = jnp.asarray(result.critical_angle_degrees)
Memory Efficiency¶
JAX arrays are lazily evaluated by default. For large batches:
import xraylabtool as xrt
# JAX arrays are lazy - computation not done until needed
result = xrt.calculate_single_material_properties("Si", 2.33, 8000)
# Force evaluation if needed (rarely necessary)
angle_ready = result.critical_angle_degrees.block_until_ready()
# This matters for timing, not usually for correctness
Getting Help¶
If you encounter issues:
Check UPGRADE_NOTES.md - Quick reference at project root
Read troubleshooting above - Common solutions
Review API docs - Function signatures unchanged
File a GitHub issue - Report bugs with: - Python version - JAX version (
python -c "import jax; print(jax.__version__)")) - Minimal code to reproduce - Error message/traceback
Resources¶
JAX Documentation - Official JAX docs
JAX Performance Tips - JIT compilation details
PyQtGraph - Interactive plotting
XRayLabTool Issues - Report problems
Next Steps¶
After upgrading:
Run your existing code - No changes should be needed
Benchmark if performance-critical - Expect 5-100x improvement
Try GPU acceleration (optional) - Install
jax[cuda12_cudnn]Explore PyQtGraph GUI - Run
python -m xraylabtool.guiReview JAX architecture guide - For deeper understanding
For questions, see the JAX Architecture Guide and Rollback Procedures.
—
Version info: This guide applies to v0.4.0+. For v0.3.0, see old documentation.