r"""
JAX-compatible Mittag-Leffler function implementations.
This module provides efficient, JAX-compatible implementations of the Mittag-Leffler
function using a hybrid strategy:
1. Taylor series for small arguments (\|z\| < 8)
2. Asymptotic expansions for large arguments (\|z\| > 8)
- Exponential expansion for positive z (Creep mode growth)
- Inverse power law expansion for negative z (Relaxation mode decay)
This approach avoids the numerical instability of Padé approximations near alpha=beta
and correctly models the exponential growth for positive arguments.
References
----------
- R. Garrappa, Numerical evaluation of two and three parameter Mittag-Leffler
functions, SIAM Journal of Numerical Analysis, 2015, 53(3), 1350-1369
- Haubold, H. J., Mathai, A. M., & Saxena, R. K. (2011). Mittag-Leffler functions
and their applications. Journal of applied mathematics, 2011.
"""
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger
logger = get_logger(__name__)
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
jax_gamma = jax.scipy.special.gamma
# ML-CONST: Module-level constant for Taylor iteration indices.
# Hoisted out of _ml_taylor to avoid re-materializing the array on every call
# (which happens once per vmapped element when using vmap over z).
# NOTE: safe_import_jax() on line 30 ensures float64 is enabled before this
# array is created. Do not reorder imports above this line.
_ML_TAYLOR_K = jnp.arange(300, dtype=jnp.float64)
[docs]
def mittag_leffler_e(z: float | jnp.ndarray, alpha: float) -> float | jnp.ndarray:
"""
One-parameter Mittag-Leffler function E_α(z).
E_α(z) = E_{α,1}(z)
Parameters
----------
z : float or jnp.ndarray
Argument(s) of the Mittag-Leffler function.
alpha : float
Order parameter, must be real and positive (0 < alpha <= 2).
Returns
-------
float or jnp.ndarray
Value(s) of E_α(z).
"""
# Validate alpha when not traced (static values only)
if not isinstance(alpha, (jax.core.Tracer, jnp.ndarray)):
if not (0 < alpha <= 2):
logger.error(
"Invalid alpha parameter for Mittag-Leffler function",
alpha=alpha,
valid_range="(0, 2]",
)
raise ValueError(f"alpha must satisfy 0 < alpha <= 2, got alpha={alpha}")
return mittag_leffler_e2(z, alpha, beta=1.0)
[docs]
@jax.jit
def mittag_leffler_e2(
z: float | jnp.ndarray, alpha: float, beta: float
) -> float | jnp.ndarray:
r"""
Two-parameter Mittag-Leffler function E_{α,β}(z).
Uses a hybrid evaluation strategy:
- \|z\| <= 8: Taylor Series (Kahan summation)
- z > 8: Positive Asymptotic Expansion (Exponential growth)
- z < -8: Negative Asymptotic Expansion (Algebraic decay)
- Smooth blending at boundaries for gradient stability.
Parameters
----------
z : float or jnp.ndarray
Argument(s) of the Mittag-Leffler function.
alpha : float
First parameter (0 < alpha <= 2).
beta : float
Second parameter.
Returns
-------
float or jnp.ndarray
Value(s) of E_{α,β}(z).
"""
# Validate alpha when not traced (static values only)
# Note: We check against Tracer and Array to allow JIT-compiled calls to pass through
if not isinstance(alpha, (jax.core.Tracer, jnp.ndarray)):
if not (0 < alpha <= 2):
logger.error(
"Invalid alpha parameter for Mittag-Leffler function",
alpha=alpha,
beta=beta,
valid_range="(0, 2]",
)
raise ValueError(f"alpha must satisfy 0 < alpha <= 2, got alpha={alpha}")
# Convert scalar to array for consistency
z_arr = jnp.asarray(z)
is_scalar = z_arr.ndim == 0
z_arr = jnp.atleast_1d(z_arr)
# Use float64 for precision
z_f64 = (
z_arr.astype(jnp.float64)
if jnp.isrealobj(z_arr)
else z_arr.astype(jnp.complex128)
)
# Vectorized computation
result = _mittag_leffler_hybrid(z_f64, alpha, beta)
# Cast back to original dtype if needed (e.g. if input was float32)
if jnp.issubdtype(z_arr.dtype, jnp.floating):
result = result.astype(z_arr.dtype)
if is_scalar:
return result[0]
return result
def _ml_taylor(z, alpha, beta, n_iter=300):
r"""Taylor series: E_{a,b}(z) = \sum_{k=0}^{N} z^k / \Gamma(a k + b).
For real z, uses vectorized log-space computation via ``gammaln`` to ensure
clean JAX gradients (the ``fori_loop`` overflow clamp produced inf in the
unused ``jnp.where`` branch, corrupting the backward pass — see KRN-011).
For complex z, falls back to iterative Kahan summation with overflow clamp
(gradient w.r.t. complex z is not required by current use cases).
"""
if jnp.iscomplexobj(z):
return _ml_taylor_complex(z, alpha, beta, n_iter)
# --- Real z: vectorized log-space computation ---
# ML-CONST: reuse the module-level arange; slice to n_iter if caller passes
# a value smaller than 300 (rare, but preserves the original API contract).
k = _ML_TAYLOR_K if n_iter == 300 else jnp.arange(n_iter, dtype=jnp.float64)
# ML-02: Fuse intermediate allocations.
# Compute log|z^k| directly without a separate safe_abs_z variable.
abs_z = jnp.abs(z)
log_abs_z = jnp.log(jnp.maximum(abs_z, 1e-300))
# log|term_k| = k*log|z| - gammaln(a*k + b), k=0 → log_zpow=0
# Fused into a single expression: avoids separate log_zpow and log_gamma arrays.
log_abs_terms = jnp.where(k == 0, 0.0, k * log_abs_z) - jax.scipy.special.gammaln(
alpha * k + beta
)
# Clamp to avoid exp overflow, then exponentiate once.
abs_terms = jnp.exp(jnp.minimum(log_abs_terms, 700.0))
# Zero out when z ≈ 0 and k > 0 (z^k → 0, but log-space gives artefacts)
abs_terms = jnp.where((k > 0) & (abs_z < 1e-300), 0.0, abs_terms)
# R11-ML-001: Zero out negligible terms to avoid unnecessary computation
abs_terms = jnp.where(abs_terms < 1e-30 * jnp.max(abs_terms), 0.0, abs_terms)
# Sign: z^k = |z|^k for z >= 0, (-1)^k |z|^k for z < 0.
# Fused sign computation — no separate neg_sign variable.
sign = jnp.where(z >= 0, 1.0, jnp.where(k % 2 == 0, 1.0, -1.0))
return jnp.sum(sign * abs_terms)
def _ml_taylor_complex(z, alpha, beta, n_iter=300):
"""Iterative Taylor series for complex z (Kahan summation + overflow clamp)."""
def body(k, state):
sum_val, c_val, z_pow = state
term = z_pow / jax_gamma(alpha * k + beta)
# Kahan summation step
y = term - c_val
t = sum_val + y
c_new = (t - sum_val) - y
sum_new = t
# Update z_power with overflow clamp (KRN-011)
z_pow_raw = z_pow * z
abs_val = jnp.abs(z_pow_raw)
scale = jnp.where(abs_val > 1e300, 1e300 / jnp.maximum(abs_val, 1e-300), 1.0)
z_pow_new = z_pow_raw * scale
return sum_new, c_new, z_pow_new
init_state = (jnp.zeros_like(z), jnp.zeros_like(z), jnp.ones_like(z))
total, _, _ = jax.lax.fori_loop(0, n_iter, body, init_state)
return total
def _ml_asymptotic_pos(z, alpha, beta):
"""
Asymptotic expansion for large positive z (Creep mode).
E_{a,b}(z) ~ (1/a) * z^((1-b)/a) * exp(z^(1/a))
KRN-005: Uses log-space evaluation with overflow cap to prevent inf
for small alpha (< 0.5) at moderate z values.
"""
inv_alpha = 1.0 / alpha
# Compute in log-space to avoid overflow
log_exponent = inv_alpha * jnp.log(jnp.maximum(z, 1e-30))
log_power = (1.0 - beta) * inv_alpha * jnp.log(jnp.maximum(z, 1e-30))
log_prefactor = jnp.log(inv_alpha) + log_power
# Cap the total log-result at 709 (exp(709) ≈ 8.2e307, near float64 max)
log_result = log_prefactor + log_exponent
log_result = jnp.minimum(log_result, 709.0)
return jnp.exp(log_result)
def _safe_rgamma(x):
"""Compute 1/Gamma(x) safely, returning 0 at poles (negative integers).
Uses "safe-where" pattern (guarded inputs, no lax.cond) so that JAX
auto-diff produces finite gradients in BOTH branches even though only
one branch's value is selected. Ref: DLMF 5.2 — 1/Γ(z).
For x < 0.5 (reflection): 1/Γ(z) = sin(πz) · Γ(1−z) / π
Since z < 0.5 ⟹ 1−z > 0.5, Γ(1−z) has no poles.
For x ≥ 0.5 (standard): 1/Γ(z) directly, no poles for z > 0.
"""
is_reflection = x < 0.5
# --- Standard branch: 1/Gamma(x) for x >= 0.5 ---
# Guard: when reflection is active, use x=1.0 (safe) to avoid NaN grads
x_std = jnp.where(is_reflection, 1.0, x)
x_std = jnp.clip(x_std, 1e-10, 170.0)
g_std = jax_gamma(x_std)
val_std = 1.0 / jnp.maximum(g_std, 1e-300)
# --- Reflection branch: sin(πz) * Gamma(1-z) / π for x < 0.5 ---
# Guard: when standard is active, use refl_arg=1.0 (safe) to avoid NaN grads
refl_arg = jnp.where(is_reflection, 1.0 - x, 1.0)
refl_arg = jnp.clip(refl_arg, 0.5, 170.0) # 1-x > 0.5 when x < 0.5
g_refl = jax_gamma(refl_arg)
sin_val = jnp.sin(jnp.pi * jnp.where(is_reflection, x, 0.0))
val_refl = sin_val * g_refl / jnp.pi
return jnp.where(is_reflection, val_refl, val_std)
def _ml_asymptotic_neg(z, alpha, beta, n_terms=20):
"""
Asymptotic expansion for large negative z (Relaxation mode).
E_{a,b}(z) ~ - sum_{k=1}^N z^(-k) / Gamma(beta - alpha*k)
ML-03: Precompute all gamma_args as a vector, then apply vmap(_safe_rgamma)
once and do a single vectorized dot product. Eliminates the fori_loop and
reduces the number of sequential kernel dispatches from n_terms to 1.
"""
inv_z = 1.0 / z
# k = 1 .. n_terms as a static vector (shape (n_terms,))
ks = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
# Precompute all gamma arguments in one shot
gamma_args = beta - alpha * ks # shape (n_terms,)
# Vectorised reciprocal-gamma over all 20 arguments at once
rgamma_vals = jax.vmap(_safe_rgamma)(gamma_args) # shape (n_terms,)
# inv_z^k = exp(k * log(inv_z)) — more numerically stable than pow iteration
# inv_z is a scalar (negative, so take absolute value first then restore sign)
# Note: z < 0 and k is integer → sign of inv_z^k follows (-1)^k
abs_inv_z = jnp.abs(inv_z)
pow_abs = jnp.exp(ks * jnp.log(jnp.maximum(abs_inv_z, 1e-300)))
inv_z_pow = jnp.where(ks % 2 == 0, pow_abs, -pow_abs) # restores sign of inv_z^k
# Signed terms and sum; series is -sum(...)
terms = inv_z_pow * rgamma_vals
return -jnp.sum(terms)
def _sigmoid_blend(x, transition, width=1.0):
"""Smooth sigmoid transition from 0 to 1 around transition point."""
return jax.nn.sigmoid((x - transition) / width)
def _smooth_blend(val1, val2, z, threshold, width=0.5):
"""
Smoothly blend between val1 (z < threshold) and val2 (z > threshold).
Parameters
----------
val1 : scalar
Value for z < threshold.
val2 : scalar
Value for z > threshold.
z : scalar
Control variable.
threshold : float
Transition point.
width : float
Width of the transition region.
Returns
-------
scalar
Blended value.
"""
weight = jax.nn.sigmoid((z - threshold) / width)
return (1.0 - weight) * val1 + weight * val2
def _mittag_leffler_hybrid(z, alpha, beta):
"""Vectorised hybrid: replaces lax.cond-inside-vmap with jnp.where.
ML-04: ``jax.lax.cond`` placed inside ``jax.vmap`` forces XLA to
evaluate *both* branches for every element sequentially (it cannot
parallelise the predicate-dependent dispatch). For N=50 points the
original scalar-kernel/vmap approach took ~150-900 ms on CPU.
The fix: hoist all branch computations to array-level (one pass each)
and select results with ``jnp.where``. XLA can then vectorise all
three branch evaluations simultaneously, reducing wall time by ~75x.
The mathematical result is identical to the original implementation:
same thresholds, same blend weights, same Taylor/asymptotic kernels.
Thresholds & Widths
-------------------
WIDTH_POS=0.2 → sigmoid leakage at z=0 < 1e-17 (positive branch)
thresh_neg → alpha-dependent; switches earlier for small alpha
cutoff_neg → 4σ from thresh_neg (pure asymptotic below this)
"""
THRESH_POS = 8.0
WIDTH_POS = 0.2
CUTOFF_POS = 10.0
# Prepare arrays (broadcast alpha/beta for the array-alpha/array-beta path)
z_arr = jnp.asarray(z)
a_arr = jnp.asarray(alpha)
b_arr = jnp.asarray(beta)
z_b, a_b, b_b = jnp.broadcast_arrays(z_arr, a_arr, b_arr)
# When alpha/beta are uniform scalars (99% of calls), extract the scalar
# to let XLA hoist the per-k gammaln computation out of the z-loop.
# When they are arrays, fall back to the per-element vmap path.
alpha_is_scalar = a_b.ndim == 0 or (a_b.ndim == 1 and a_b.size == 1)
beta_is_scalar = b_b.ndim == 0 or (b_b.ndim == 1 and b_b.size == 1)
if alpha_is_scalar and beta_is_scalar:
# --- Fast path: vectorised over z, scalar alpha/beta ---
a_s = a_b.ravel()[0] if a_b.ndim > 0 else a_b
b_s = b_b.ravel()[0] if b_b.ndim > 0 else b_b
thresh_neg = -0.9 - 7.1 * a_s
width_neg = 0.1 + 0.4 * a_s
cutoff_neg = thresh_neg - 4.0 * width_neg
# Taylor — computed once for all z (gammaln is hoisted by XLA)
val_taylor = jax.vmap(lambda zi: _ml_taylor(zi, a_s, b_s))(z_b)
# Asymptotic branches with guarded inputs
z_pos_safe = jnp.maximum(z_b, 1.0)
z_neg_safe = jnp.minimum(z_b, thresh_neg)
val_pos = jax.vmap(lambda zi: _ml_asymptotic_pos(zi, a_s, b_s))(z_pos_safe)
val_neg_raw = jax.vmap(lambda zi: _ml_asymptotic_neg(zi, a_s, b_s))(z_neg_safe)
# For alpha >= 1 the neg asymptotic diverges; fall back to Taylor
val_neg = jnp.where(a_s < 1.0, val_neg_raw, val_taylor)
# Blend neg <-> taylor (left transition)
w_neg = jax.nn.sigmoid((z_b - thresh_neg) / width_neg)
blended = (1.0 - w_neg) * val_neg + w_neg * val_taylor
# Blend blended <-> pos (right transition)
w_pos = jax.nn.sigmoid((z_b - THRESH_POS) / WIDTH_POS)
result = (1.0 - w_pos) * blended + w_pos * val_pos
# Override with pure branches outside blend regions
result = jnp.where(z_b > CUTOFF_POS, val_pos, result)
result = jnp.where(z_b < cutoff_neg, val_neg, result)
return result
# --- Slow path: array alpha/beta — retain original per-element kernel ---
# This path is rarely exercised (only when alpha/beta are data arrays).
def _kernel(z_val, a_val, b_val):
thresh_neg = -0.9 - 7.1 * a_val
width_neg = 0.1 + 0.4 * a_val
cutoff_neg = thresh_neg - 4.0 * width_neg
def _pure_pos(_):
return _ml_asymptotic_pos(z_val, a_val, b_val)
def _pure_neg(_):
val = _ml_asymptotic_neg(z_val, a_val, b_val)
val_taylor = _ml_taylor(z_val, a_val, b_val, n_iter=300)
return jnp.where(a_val < 1.0, val, val_taylor)
def _blended_region(_):
val_taylor = _ml_taylor(z_val, a_val, b_val, n_iter=300)
z_pos_safe = jnp.maximum(z_val, 1.0)
val_pos = _ml_asymptotic_pos(z_pos_safe, a_val, b_val)
z_neg_safe = jnp.minimum(z_val, thresh_neg)
val_neg_raw = _ml_asymptotic_neg(z_neg_safe, a_val, b_val)
val_neg = jnp.where(a_val < 1.0, val_neg_raw, val_taylor)
res = _smooth_blend(val_neg, val_taylor, z_val, thresh_neg, width_neg)
res = _smooth_blend(res, val_pos, z_val, THRESH_POS, WIDTH_POS)
return res
return jax.lax.cond(
z_val > CUTOFF_POS,
_pure_pos,
lambda _: jax.lax.cond(
z_val < cutoff_neg, _pure_neg, _blended_region, operand=None
),
operand=None,
)
return jax.vmap(_kernel)(z_b, a_b, b_b)
# Convenience aliases
ml_e = mittag_leffler_e
ml_e2 = mittag_leffler_e2
__all__ = [
"mittag_leffler_e",
"mittag_leffler_e2",
"ml_e",
"ml_e2",
]