Installation

Requirements

  • Python 3.12 or higher (3.8-3.11 NOT supported)

  • JAX >= 0.8.3

  • jaxlib >= 0.8.3 (must be compatible with JAX version)

  • NumPy >= 2.3.5

  • SciPy >= 1.17.0

  • NLSQ >= 0.6.10 (GPU-accelerated optimization)

  • NumPyro >= 0.20.0 (Bayesian inference)

  • ArviZ >= 0.23.4 (Bayesian visualization)

Basic Installation

Install from PyPI:

pip install rheojax

Development Installation

For development, clone the repository and install with uv:

git clone https://github.com/imewei/rheojax.git
cd rheojax
uv sync              # Install all dependencies from uv.lock
pre-commit install   # Set up pre-commit hooks

Optional Dependencies

GPU Support (Linux + CUDA 12+ or 13+)

GPU acceleration provides 20-100x speedup for large datasets.

Platform Requirements:

  • ✅ Linux + NVIDIA GPU + CUDA 12.1+: Full GPU acceleration

  • ✅ Linux + NVIDIA GPU + CUDA 13.x: Full GPU acceleration with latest toolkit

  • ❌ macOS: CPU-only (Apple Silicon/Intel, no NVIDIA GPU support)

  • ❌ Windows: CPU-only (CUDA support experimental/unstable)

Quick Install (from repository):

make install-jax-gpu  # Auto-detects CUDA 12 or 13

Manual Install:

# Remove ALL existing JAX/CUDA packages first (prevents plugin conflicts)
pip uninstall -y jax jaxlib \
    jax-cuda13-plugin jax-cuda13-pjrt \
    jax-cuda12-plugin jax-cuda12-pjrt

# For CUDA 13.x (SM >= 7.5):
pip install "jax[cuda13-local]"

# For CUDA 12.x (SM >= 5.2):
pip install "jax[cuda12-local]"

Or with project extras:

pip install "rheojax[gpu_cuda13]"   # CUDA 13
pip install "rheojax[gpu_cuda12]"   # CUDA 12

Important

Never install both cuda12 and cuda13 plugins simultaneously. Only ONE CUDA plugin set can be active — having both causes PJRT registration conflicts.

Requirements:

  • System CUDA 12.1+ or 13.x pre-installed (not bundled)

  • NVIDIA driver >= 525 (CUDA 12) or >= 560 (CUDA 13)

  • Linux x86_64 or aarch64

Verify GPU Detection:

import jax
print(jax.devices())  # Should show [cuda(id=0)] if GPU detected

Diagnostics:

make gpu-check      # Verify GPU backend, devices, SVD computation
make gpu-diagnose   # Check for plugin conflicts, version mismatches

Verifying Installation

Verify the installation:

import rheojax
print(f"RheoJAX version: {rheojax.__version__}")

import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

This should display the version numbers and available devices (CPU or GPU) without errors.

Troubleshooting

GPU Issues

Symptom

Cause

Fix

plugin version X is not compatible with jaxlib Y

Plugin/jaxlib version mismatch

Uninstall all, reinstall: make install-jax-gpu

PJRT_Api already exists for device type cuda

Both cuda12 and cuda13 plugins installed

Uninstall all, reinstall only ONE

nvcc not found

CUDA toolkit missing or not in PATH

sudo apt install nvidia-cuda-toolkit or export PATH=/usr/local/cuda/bin:$PATH

Backend: cpu (GPU exists)

GPU JAX packages not installed

make install-jax-gpu

libcuda.so not found

CUDA libs not in LD_LIBRARY_PATH

export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

JaxRuntimeError: NOT_FOUND: cusolver_*

Plugin/jaxlib version mismatch

Same as first row

Use make gpu-diagnose to automatically detect plugin conflicts and version mismatches.

JAX Installation Issues

If you encounter issues with JAX installation, refer to the JAX installation guide.

  • GPU acceleration requires Linux + CUDA 12.1+ or 13.x + NVIDIA driver >= 525 (CUDA 12) or >= 560 (CUDA 13)

  • macOS and Windows only support CPU mode

Import Errors

If you encounter import errors, ensure all dependencies are installed:

uv sync