Installation¶
Requirements¶
Python 3.12 or higher (3.8-3.11 NOT supported)
JAX >= 0.8.3
jaxlib >= 0.8.3 (must be compatible with JAX version)
NumPy >= 2.3.5
SciPy >= 1.17.0
NLSQ >= 0.6.10 (GPU-accelerated optimization)
NumPyro >= 0.20.0 (Bayesian inference)
ArviZ >= 0.23.4 (Bayesian visualization)
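The minimum versions above can be checked against the current environment before installing. The following is a minimal sketch using only the standard library. The helper names (`version_tuple`, `installed_version`, `check_requirements`) are illustrative and not part of RheoJAX, and the lowercase distribution names are assumed to match the names published on PyPI.

```python
from importlib import metadata
import re

# Minimum versions copied from the requirements list above.
# Assumption: these lowercase names match the PyPI distribution names.
MINIMUMS = {
    "jax": "0.8.3",
    "jaxlib": "0.8.3",
    "numpy": "2.3.5",
    "scipy": "1.17.0",
    "nlsq": "0.6.10",
    "numpyro": "0.20.0",
    "arviz": "0.23.4",
}


def version_tuple(text):
    """Turn '2.3.5' or '0.8.3rc1' into a tuple of leading integers.

    Pre-release suffixes are ignored, so '0.8.3rc1' compares equal to '0.8.3'.
    The result is padded with zeros to at least three components.
    """
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def installed_version(name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_requirements(minimums, lookup=installed_version):
    """Return {package: problem} for every minimum that is not met."""
    problems = {}
    for name, minimum in minimums.items():
        found = lookup(name)
        if found is None:
            problems[name] = "not installed"
        elif version_tuple(found) < version_tuple(minimum):
            problems[name] = f"found {found}, need >= {minimum}"
    return problems


if __name__ == "__main__":
    for package, problem in check_requirements(MINIMUMS).items():
        print(f"{package}: {problem}")
```

An empty result means every listed package is present at or above its minimum. The `lookup` parameter exists only so the check can be exercised without touching the real environment.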
Basic Installation¶
Install from PyPI:
pip install rheojax
Development Installation¶
For development, clone the repository and install with uv:
git clone https://github.com/imewei/rheojax.git
cd rheojax
uv sync # Install all dependencies from uv.lock
pre-commit install # Set up pre-commit hooks
Optional Dependencies¶
GPU Support (Linux + CUDA 12+ or 13+)¶
GPU acceleration can provide a 20-100x speedup on large datasets.
Platform Requirements:
✅ Linux + NVIDIA GPU + CUDA 12.1+: Full GPU acceleration
✅ Linux + NVIDIA GPU + CUDA 13.x: Full GPU acceleration with latest toolkit
❌ macOS: CPU-only (Apple Silicon/Intel, no NVIDIA GPU support)
❌ Windows: CPU-only (CUDA support experimental/unstable)
Quick Install (from repository):
make install-jax-gpu # Auto-detects CUDA 12 or 13
Manual Install:
# Remove ALL existing JAX/CUDA packages first (prevents plugin conflicts)
pip uninstall -y jax jaxlib \
jax-cuda13-plugin jax-cuda13-pjrt \
jax-cuda12-plugin jax-cuda12-pjrt
# For CUDA 13.x (SM >= 7.5):
pip install "jax[cuda13-local]"
# For CUDA 12.x (SM >= 5.2):
pip install "jax[cuda12-local]"
Or with project extras:
pip install "rheojax[gpu_cuda13]" # CUDA 13
pip install "rheojax[gpu_cuda12]" # CUDA 12
Important
Never install the cuda12 and cuda13 plugins at the same time. Only ONE CUDA plugin set can be active; having both causes PJRT registration conflicts.
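A quick way to spot this conflict is to list the installed distributions and look for both plugin families. The sketch below assumes the four distribution names from the uninstall command above. The function names are hypothetical, and this is not how `make gpu-diagnose` is implemented.

```python
from importlib import metadata
import re

# Distribution names taken from the uninstall command above.
CUDA12 = {"jax-cuda12-plugin", "jax-cuda12-pjrt"}
CUDA13 = {"jax-cuda13-plugin", "jax-cuda13-pjrt"}


def normalize(name):
    """Normalize a distribution name: lowercase, runs of -_. become '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def find_plugin_conflict(installed_names):
    """Return (cuda12_found, cuda13_found) if both families are present, else None."""
    names = {normalize(n) for n in installed_names}
    found12 = sorted(names & CUDA12)
    found13 = sorted(names & CUDA13)
    if found12 and found13:
        return found12, found13
    return None


def installed_distribution_names():
    """List the names of all distributions visible to this interpreter."""
    names = []
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name:
            names.append(name)
    return names


if __name__ == "__main__":
    conflict = find_plugin_conflict(installed_distribution_names())
    if conflict is None:
        print("No cuda12/cuda13 plugin conflict found")
    else:
        print(f"Conflict: cuda12={conflict[0]} cuda13={conflict[1]}")
```

If a conflict is reported, the uninstall step under Manual Install removes both families so that only one can be reinstalled.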
Requirements:
System CUDA 12.1+ or 13.x pre-installed (not bundled)
NVIDIA driver >= 525 (CUDA 12) or >= 560 (CUDA 13)
Linux x86_64 or aarch64
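The driver minimums above can be checked from Python. This sketch compares the major driver version against the thresholds listed here (525 for CUDA 12, 560 for CUDA 13). It assumes `nvidia-smi` is on the PATH and supports the `--query-gpu=driver_version` option. The function names are illustrative only.

```python
import subprocess

# Minimum NVIDIA driver major version per CUDA major version (from the list above).
MIN_DRIVER = {12: 525, 13: 560}


def driver_major(version_text):
    """Extract the major component from a driver string such as '535.104.05'."""
    return int(version_text.strip().split(".")[0])


def driver_supports(version_text, cuda_major):
    """Return True if the driver meets the minimum for the given CUDA major version."""
    return driver_major(version_text) >= MIN_DRIVER[cuda_major]


def query_driver_version():
    """Ask nvidia-smi for the driver version; return None if it is unavailable."""
    try:
        output = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    lines = output.strip().splitlines()
    return lines[0].strip() if lines else None


if __name__ == "__main__":
    version = query_driver_version()
    if version is None:
        print("nvidia-smi not available; cannot determine driver version")
    else:
        for cuda_major in sorted(MIN_DRIVER):
            status = "ok" if driver_supports(version, cuda_major) else "too old"
            print(f"driver {version} for CUDA {cuda_major}: {status}")
```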
Verify GPU Detection:
import jax
print(jax.devices()) # Should list a CUDA device, e.g. [CudaDevice(id=0)], if a GPU is detected
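For scripts that need to branch on whether a GPU was found, the device list can be summarized rather than read by eye. The sketch below relies on the `platform` attribute of JAX device objects. It accepts both "gpu" and "cuda" as GPU platform names, which is a defensive assumption since the reported string may differ between JAX versions. `summarize_devices` is a hypothetical helper, not a RheoJAX function.

```python
def summarize_devices(devices):
    """Summarize a list of JAX devices by platform and count the GPU devices."""
    platforms = sorted({str(d.platform).lower() for d in devices})
    # Assumption: CUDA devices report "gpu" (or possibly "cuda") as their platform.
    gpu_count = sum(1 for d in devices if str(d.platform).lower() in ("gpu", "cuda"))
    return {
        "platforms": platforms,
        "gpu_count": gpu_count,
        "gpu_available": gpu_count > 0,
    }


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not installed")
    else:
        print(summarize_devices(jax.devices()))
```

On a CPU-only machine the summary should report no GPU devices, which matches the macOS and Windows rows in the platform list.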
Diagnostics:
make gpu-check # Verify GPU backend, devices, SVD computation
make gpu-diagnose # Check for plugin conflicts, version mismatches
Verifying Installation¶
Run the following in a Python session to confirm that RheoJAX and JAX import correctly:
import jax
import rheojax
print(f"RheoJAX version: {rheojax.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
This should display the version numbers and available devices (CPU or GPU) without errors.
Troubleshooting¶
GPU Issues¶
| Symptom | Cause | Fix |
|---|---|---|
| | Plugin/jaxlib version mismatch | Uninstall all, then reinstall (see Manual Install above) |
| | Both cuda12 and cuda13 plugins installed | Uninstall all, reinstall only ONE |
| | CUDA toolkit missing or not in PATH | |
| | GPU JAX packages not installed | |
| | CUDA libs not in LD_LIBRARY_PATH | |
| | Plugin/jaxlib version mismatch | Same as first row |
Use make gpu-diagnose to automatically detect plugin conflicts and version mismatches.
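The PATH and LD_LIBRARY_PATH causes listed in the table can also be checked by hand. The sketch below looks for `nvcc` on the PATH and for directories in LD_LIBRARY_PATH that contain a `libcudart` file. This is a heuristic only: it assumes the CUDA runtime library is named `libcudart*`, and it does not see libraries found through other mechanisms such as the system linker cache. `cuda_environment_report` is a hypothetical helper and is not what `make gpu-diagnose` runs.

```python
import os
import shutil


def split_paths(value):
    """Split a PATH-style string into its non-empty entries."""
    return [p for p in (value or "").split(os.pathsep) if p]


def cuda_environment_report(environ, which=shutil.which, listdir=os.listdir):
    """Report where nvcc is found and which LD_LIBRARY_PATH dirs hold libcudart."""
    # shutil.which falls back to the process PATH when path is None.
    nvcc = which("nvcc", path=environ.get("PATH"))
    cudart_dirs = []
    for directory in split_paths(environ.get("LD_LIBRARY_PATH")):
        try:
            entries = listdir(directory)
        except OSError:
            # Skip entries that do not exist or cannot be read.
            continue
        if any(name.startswith("libcudart") for name in entries):
            cudart_dirs.append(directory)
    return {"nvcc": nvcc, "cudart_dirs": cudart_dirs}


if __name__ == "__main__":
    report = cuda_environment_report(os.environ)
    print(f"nvcc: {report['nvcc'] or 'not found on PATH'}")
    print(f"libcudart directories: {report['cudart_dirs'] or 'none in LD_LIBRARY_PATH'}")
```

The `which` and `listdir` parameters are injection points so the check can be tested against a fake filesystem. In normal use only the environment mapping is passed.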
JAX Installation Issues¶
If you encounter issues with JAX installation, refer to the JAX installation guide.
GPU acceleration requires Linux with CUDA 12.1+ or 13.x and an NVIDIA driver >= 525 (CUDA 12) or >= 560 (CUDA 13)
macOS and Windows support CPU mode only
Import Errors¶
If you encounter import errors, ensure all dependencies are installed:
uv sync