"""Tensorial Elasto-Plastic Model (EPM) implementation.
This module implements the full tensorial (3-component) stress formulation for EPM
simulations. It tracks the stress tensor [σ_xx, σ_yy, σ_xy] in 2D plane strain,
enabling prediction of normal stress differences (N₁, N₂), anisotropic flow behavior,
and kinematic hardening.
Key Features:
- Full tensorial stress state per lattice site
- Von Mises and Hill anisotropic yield criteria
- Normal stress difference predictions (N₁, N₂)
- Fitting: shear-only σ_xy (combined [σ_xy, N₁] fitting is not yet supported)
"""
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.epm.base import EPMBase
from rheojax.utils.epm_kernels_tensorial import (
make_tensorial_propagator_q,
tensorial_epm_step,
)
jax, jnp = safe_import_jax()
@ModelRegistry.register(
    "tensorial_epm",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class TensorialEPM(EPMBase):
    """3-Component Tensorial Lattice EPM.

    A mesoscopic model for amorphous solids that explicitly tracks the full
    stress tensor, enabling predictions of normal stress differences and
    anisotropic flow.

    Physics:
        - Lattice of elastoplastic blocks with tensorial stress state.
        - Elastic loading (affine) for all stress components.
        - Von Mises or Hill yield criterion for anisotropic materials.
        - Component-wise plastic flow rule (Prandtl-Reuss).
        - Tensorial Eshelby propagator for stress redistribution.

    Parameters:
        mu (float): Shear modulus. Default 1.0.
        nu (float): Poisson's ratio for plane strain. Default 0.48 (avoid 0.5
            singularity).
        tau_pl (float): Base plastic relaxation timescale (legacy). Default 1.0.
        tau_pl_shear (float): Plastic relaxation time for shear. Default 1.0.
        tau_pl_normal (float): Plastic relaxation time for normal stresses.
            Default 1.0.
        sigma_c_mean (float): Mean yield threshold. Default 1.0.
        sigma_c_std (float): Disorder strength (std dev of thresholds).
            Default 0.1.
        n_fluid (float): Fluidity exponent; for the overstress form the
            steady-state excess shear stress scales as
            (|γ̇|·τ_pl_shear)^(1/n_fluid). Default 1.0.
        smoothing_width (float): Width for smooth yielding approx (inference
            only). Default 0.1.
        w_N1 (float): Weight for N₁ in combined fitting loss. Default 1.0.
        hill_H (float): Hill anisotropy parameter H. Default 0.5.
        hill_N (float): Hill anisotropy parameter N. Default 1.5.

    Configuration:
        L (int): Lattice size (LxL). Default 64.
        dt (float): Time step. Default 0.01.
        yield_criterion (str): "von_mises" or "hill". Default "von_mises".
        fluidity_form (str): "linear", "power" or "overstress".
            Default "overstress".
        n_bayesian_steps (int): Time steps used for Bayesian inference.
            Default 200.
    """

    def __init__(
        self,
        L: int = 64,
        dt: float = 0.01,
        mu: float = 1.0,
        nu: float = 0.48,  # Avoid nu=0.5 (incompressible singularity in plane strain)
        tau_pl: float = 1.0,
        tau_pl_shear: float = 1.0,
        tau_pl_normal: float = 1.0,
        sigma_c_mean: float = 1.0,
        sigma_c_std: float = 0.1,
        n_fluid: float = 1.0,
        yield_criterion: str = "von_mises",
        n_bayesian_steps: int = 200,
        fluidity_form: str = "overstress",
    ):
        """Initialize the Tensorial EPM.

        Args:
            L: Lattice size (LxL grid).
            dt: Time step for integration.
            mu: Shear modulus.
            nu: Poisson's ratio for plane strain constraint.
            tau_pl: Base plastic relaxation time (for compatibility).
            tau_pl_shear: Plastic relaxation time for shear components.
            tau_pl_normal: Plastic relaxation time for normal stress components.
            sigma_c_mean: Mean yield threshold.
            sigma_c_std: Standard deviation of yield thresholds (disorder).
            n_fluid: Fluidity exponent forwarded to the base class / kernel.
            yield_criterion: Yield criterion name ("von_mises" or "hill").
            n_bayesian_steps: Number of time steps for Bayesian inference.
                Default 200.
            fluidity_form: Plastic fluidity law, one of "linear", "power" or
                "overstress"; forwarded to the tensorial kernel on every step.

        Raises:
            ValueError: If ``fluidity_form`` or ``yield_criterion`` is not one
                of the supported names.
        """
        # Validate static configuration before doing any base-class setup or
        # propagator precomputation, so bad input fails fast.
        #
        # TensorialEPM uses the Prandtl-Reuss flow rule via its own kernel
        # (``rheojax.utils.epm_kernels_tensorial``). The kernel supports all
        # three fluidity forms — see `compute_plastic_strain_rate` for the
        # mathematical definitions. "overstress" is the default and
        # recommended choice for yield-stress fluids (emulsions, gels, foams,
        # pastes); it produces a full Herschel-Bulkley flow curve with the
        # von-Mises pure-shear plateau at sigma_c_mean / sqrt(3).
        if fluidity_form not in ("linear", "power", "overstress"):
            raise ValueError(
                f"fluidity_form must be 'linear', 'power', or 'overstress'; "
                f"got {fluidity_form!r}."
            )
        if yield_criterion not in ["von_mises", "hill"]:
            raise ValueError(
                f"Unknown yield criterion: {yield_criterion}. "
                "Must be 'von_mises' or 'hill'."
            )

        super().__init__(
            L=L,
            dt=dt,
            mu=mu,
            tau_pl=tau_pl,
            sigma_c_mean=sigma_c_mean,
            sigma_c_std=sigma_c_std,
            n_fluid=n_fluid,
            n_bayesian_steps=n_bayesian_steps,
            fluidity_form=fluidity_form,
        )

        # Tensorial-specific parameters
        self.parameters.add(
            "nu",
            nu,
            bounds=(0.3, 0.5),
            units="",
            description="Poisson's ratio for plane strain",
        )
        self.parameters.add(
            "tau_pl_shear",
            tau_pl_shear,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for shear",
        )
        self.parameters.add(
            "tau_pl_normal",
            tau_pl_normal,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for normal stresses",
        )
        self.parameters.add(
            "w_N1",
            1.0,
            bounds=(0.1, 10.0),
            units="",
            description="Weight for N₁ in combined fitting loss",
        )
        self.parameters.add(
            "hill_H",
            0.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter H",
        )
        self.parameters.add(
            "hill_N",
            1.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter N",
        )

        # Yield criterion (static configuration, validated above)
        self.yield_criterion = yield_criterion

        # Precompute the tensorial propagator with mu=1.0 as normalization;
        # it is scaled by the actual mu at execution time. Remember which nu
        # it was built for so it can be rebuilt if nu is changed later.
        self._propagator_nu = float(nu)
        self._propagator_q_norm = make_tensorial_propagator_q(L, nu=nu, mu=1.0)

    def _get_propagator_q_norm(self, nu: float) -> jax.Array:
        """Return the mu=1 tensorial propagator for Poisson's ratio ``nu``.

        The propagator depends on ``nu``, which is a fittable parameter. The
        cached propagator is rebuilt whenever the requested ``nu`` differs
        from the one it was built with, so predictions after fitting (or a
        manual parameter update) use a redistribution kernel consistent with
        the current ``nu``.

        Args:
            nu: Current Poisson's ratio (a concrete scalar).

        Returns:
            The cached (possibly rebuilt) normalized propagator.
        """
        nu_value = float(nu)
        if getattr(self, "_propagator_nu", None) != nu_value:
            self._propagator_q_norm = make_tensorial_propagator_q(
                self.L, nu=nu_value, mu=1.0
            )
            self._propagator_nu = nu_value
        return self._propagator_q_norm

    def _init_stress(self, key: jax.Array) -> jax.Array:
        """Initialize tensorial stress field.

        Args:
            key: PRNG key (unused for zero initialization).

        Returns:
            Zero-initialized stress tensor of shape (3, L, L) for
            [σ_xx, σ_yy, σ_xy].
        """
        # Start relaxed (zero stress)
        return jnp.zeros((3, self.L, self.L))

    def _is_scalar_epm(self) -> bool:
        """Check if this is a scalar (not tensorial) EPM.

        Returns False for TensorialEPM (tensorial stress field). This causes
        model_function to use general methods instead of scalar JIT kernels.
        """
        return False

    def _get_param_dict(self) -> dict[str, float]:
        """Extract parameters as dictionary for kernel calls.

        Extends the base class method with the tensorial parameters.

        Returns:
            Dictionary with all EPM parameters including tensorial ones.

        Raises:
            ValueError: If any tensorial parameter is unset (None).
        """
        base_params = super()._get_param_dict()
        tensorial_params = {}
        for name in ("nu", "tau_pl_shear", "tau_pl_normal", "hill_H", "hill_N"):
            value = self.parameters.get_value(name)
            if value is None:
                raise ValueError(f"Parameter '{name}' must be set before use")
            tensorial_params[name] = value
        return {**base_params, **tensorial_params}

    def _epm_step(
        self,
        state: tuple[jax.Array, jax.Array, float, jax.Array],
        propagator_q: jax.Array,
        shear_rate: float,
        dt: float,
        params: dict,
        smooth: bool,
    ) -> tuple[jax.Array, jax.Array, float, jax.Array]:
        """Perform one tensorial EPM time step.

        Delegates to the ``tensorial_epm_step`` kernel.

        Args:
            state: Current state (stress, thresholds, strain, key).
            propagator_q: Precomputed tensorial propagator.
            shear_rate: Imposed macroscopic shear rate.
            dt: Time step size.
            params: Model parameters dictionary.
            smooth: Use smooth yielding (tanh) or hard threshold (step).

        Returns:
            Updated state tuple (new_stress, thresholds, new_strain, key).
        """
        stress, thresholds, strain, key = state
        # Forward the model's fluidity_form so that NLSQ fitting, NUTS
        # inference, and forward .predict() all use the same constitutive law.
        new_stress = tensorial_epm_step(
            stress=stress,
            thresholds=thresholds,
            strain_rate=shear_rate,
            dt=dt,
            propagator=propagator_q,
            params=params,
            smooth=smooth,
            yield_criterion=self.yield_criterion,
            fluidity_form=self.fluidity_form,
        )
        # Update accumulated strain
        new_strain = strain + shear_rate * dt
        return (new_stress, thresholds, new_strain, key)

    def get_shear_stress(self, stress: jax.Array) -> jax.Array:
        """Extract shear stress component from stress tensor.

        Args:
            stress: Stress tensor of shape (3, L, L) or (..., 3, L, L).

        Returns:
            Shear stress σ_xy, shape (L, L) or (..., L, L).
        """
        # Component ordering: [σ_xx, σ_yy, σ_xy]
        if stress.ndim == 3:
            return stress[2]
        else:
            return stress[..., 2, :, :]

    def get_normal_stress_differences(
        self, stress: jax.Array, nu: float | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Compute normal stress differences from stress tensor.

        For plane strain: σ_zz = ν(σ_xx + σ_yy)

        Normal stress differences:
            - N₁ = σ_xx - σ_yy
            - N₂ = σ_yy - σ_zz

        Args:
            stress: Stress tensor of shape (3, L, L) or (..., 3, L, L).
            nu: Poisson's ratio. If None, uses parameter value.

        Returns:
            Tuple (N₁, N₂), each of shape (L, L) or (..., L, L).

        Raises:
            ValueError: If ``nu`` is None and the 'nu' parameter is unset.
        """
        if nu is None:
            nu = self.parameters.get_value("nu")
            if nu is None:
                raise ValueError("Parameter 'nu' must be set before use")
        # Index the component axis relative to the spatial axes so batched
        # (..., 3, L, L) inputs work just like in get_shear_stress.
        sigma_xx = stress[..., 0, :, :]
        sigma_yy = stress[..., 1, :, :]
        # Plane strain constraint
        sigma_zz = nu * (sigma_xx + sigma_yy)
        N1 = sigma_xx - sigma_yy
        N2 = sigma_yy - sigma_zz
        return N1, N2

    def predict_normal_stresses(
        self, data: RheoData, **kwargs
    ) -> tuple[jax.Array, jax.Array]:
        """Convenience method to predict normal stress differences.

        Intended to run the simulation and extract N₁ and N₂ spatial averages
        over time.

        Args:
            data: RheoData with protocol specification.
            **kwargs: Additional arguments passed to predict().

        Returns:
            Tuple (N₁_array, N₂_array) with time-averaged values.

        Raises:
            NotImplementedError: Not yet implemented (future feature). For a
                flow curve, ``predict()`` already reports N₁ in
                ``result.metadata['N1']``.
        """
        raise NotImplementedError(
            "predict_normal_stresses() not yet implemented. "
            "Use predict() with test_mode='flow_curve', which returns σ_xy "
            "in result.y and N₁ in result.metadata['N1']."
        )

    # Override flow_curve to return shear stress with N₁ in metadata
    def _run_flow_curve(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Steady state flow curve: Stress vs Shear Rate.

        For tensorial EPM, returns shear stress σ_xy with N₁ stored in
        metadata.

        Args:
            data: RheoData with x=shear_rates.
            key: PRNG key.
            propagator_q: Precomputed propagator.
            params: Model parameters.
            smooth: Use smooth yielding.

        Returns:
            RheoData with x=shear_rates, y=sigma_xy; metadata contains 'N1'
            with first normal stress differences.
        """
        shear_rates = data.x
        sigma_c_mean = params["sigma_c_mean"]
        tau_pl_shear = params.get("tau_pl_shear", params.get("tau_pl", 1.0))
        n_fluid = params.get("n_fluid", 1.0)
        sqrt3 = jnp.sqrt(3.0)

        # Guard against pathological parameter combinations that NLSQ can
        # transiently probe during fitting (very small sigma_c_mean or
        # n_fluid, extreme shear rates). The clamps keep the analytical
        # warm-start formula numerically well-defined without affecting its
        # value at reasonable parameters.
        scm_safe = jnp.maximum(sigma_c_mean, 1e-6)
        plateau = scm_safe / sqrt3

        def scan_fn(gdot):
            # Run simulation for sufficient steps to reach steady state
            n_steps = 1000
            state = self._init_state(key)
            stress0, thresholds, strain, k = state

            # Warm-start sigma_xy at the analytical overstress steady state
            # for pure shear with von Mises yielding.
            #
            # The tensorial stress update is dσ_xy/dt = μγ̇ − 2μ·ε̇^p_xy (see
            # the Budrikis & Zapperi 2013 reference at the top of
            # epm_kernels_tensorial.py), so at steady state ε̇^p_xy = γ̇/2 in
            # the tensor convention. Solving the overstress Prandtl-Reuss
            # rule for pure shear gives:
            #
            #   σ_xy = σ_c_mean/√3
            #        + (1/√3) · (√3 / 2)^(1/n_fluid)
            #                 · σ_c_mean^((n_fluid − 1)/n_fluid)
            #                 · (|γ̇|·τ_pl_shear)^(1/n_fluid)
            #
            # Special cases:
            #   n_fluid = 1 → σ_xy = σ_c_mean/√3 + |γ̇|·τ_pl_shear / 2
            #   n_fluid = 2 → σ_xy = σ_c_mean/√3
            #                         + √( σ_c_mean · |γ̇| · τ_pl_shear / (2·√3) )
            #
            # Starting near this target avoids long transient loading times
            # at low shear rates and destabilising FFT redistribution at high
            # rates. Normal components (indices 0, 1) are left at zero — they
            # develop self-consistently from the Prandtl-Reuss flow.
            gdot_abs = jnp.abs(gdot)
            n_safe = jnp.maximum(n_fluid, 1e-3)
            tau_safe = jnp.maximum(tau_pl_shear, 1e-6)
            inv_n = 1.0 / n_safe
            excess_base = jnp.maximum(gdot_abs * tau_safe, 0.0)
            warm_excess = (
                (1.0 / sqrt3)
                * (sqrt3 / 2.0) ** inv_n
                * scm_safe ** ((n_safe - 1.0) * inv_n)
                * excess_base**inv_n
            )
            # Final clamp: if any component is NaN/inf, fall back to the
            # plateau value so the simulation still runs.
            sigma_xy_warm_raw = jnp.sign(gdot) * (plateau + warm_excess)
            sigma_xy_warm = jnp.where(
                jnp.isfinite(sigma_xy_warm_raw),
                sigma_xy_warm_raw,
                jnp.sign(gdot) * plateau,
            )
            stress0 = stress0.at[2].add(sigma_xy_warm)
            state = (stress0, thresholds, strain, k)

            def body(carrier, _):
                new_state = self._epm_step(
                    carrier, propagator_q, gdot, self.dt, params, smooth
                )
                stress_tensor = new_state[0]  # Shape (3, L, L)
                # Extract shear stress and N₁
                sigma_xy_mean = jnp.mean(stress_tensor[2])
                N1, _ = self.get_normal_stress_differences(
                    stress_tensor, params["nu"]
                )
                return new_state, jnp.array([sigma_xy_mean, jnp.mean(N1)])

            _, history = jax.lax.scan(body, state, None, length=n_steps)
            # Average the last 50% (history shape: (n_steps, 2)) for the
            # steady state → [σ_xy, N₁]
            return jnp.mean(history[n_steps // 2 :], axis=0)

        # Vectorize over shear rates
        stresses = jax.vmap(scan_fn)(shear_rates)  # Shape: (n_rates, 2)

        # NaN safety net. NLSQ fitting can transiently probe parameter
        # combinations that make the explicit-Euler kernel unstable (e.g.,
        # very small sigma_c_mean with large n_fluid at high gdot), producing
        # NaN in the steady-state average. Replace those entries with the
        # analytical plateau sigma_c_mean/sqrt(3) so downstream validation
        # does not explode — the optimiser will naturally move away from NaN
        # regions in the next iteration.
        sigma_xy_raw = stresses[:, 0]
        N1_raw = stresses[:, 1]
        sigma_xy = jnp.where(jnp.isfinite(sigma_xy_raw), sigma_xy_raw, plateau)
        N1 = jnp.where(jnp.isfinite(N1_raw), N1_raw, 0.0)

        # Store N₁ in metadata
        result_metadata = data.metadata.copy() if data.metadata else {}
        result_metadata["N1"] = N1
        return RheoData(
            x=shear_rates,
            y=sigma_xy,
            initial_test_mode="flow_curve",
            metadata=result_metadata,
        )

    def _predict(self, X, **kwargs) -> RheoData:
        """Simulate the model for the given protocol.

        Args:
            X: Input data - can be RheoData or numpy/JAX array.
            **kwargs:
                test_mode (str): 'flow_curve', 'startup', 'relaxation',
                    'creep', 'oscillation'.
                smooth (bool): Use smooth yielding (default False for
                    simulation, True for fitting).
                seed (int): Random seed (default 0).

        Returns:
            RheoData with simulation results.
                - flow_curve: y is σ_xy array, metadata['N1'] contains N₁
                - Other protocols: y has shape (n_points,) with σ_xy only

        Raises:
            ValueError: If the resolved test_mode is unknown.
        """
        if isinstance(X, RheoData):
            rheo_data = X
            test_mode = kwargs.get("test_mode", rheo_data.test_mode)
        else:
            # Raw array input - wrap in RheoData
            test_mode = kwargs.get("test_mode")
            if test_mode is None:
                test_mode = getattr(self, "_test_mode", "flow_curve")
            x_array = jnp.asarray(X, dtype=jnp.float64)
            # Dummy y for the RheoData constructor
            dummy_y = jnp.zeros_like(x_array)
            # Copy protocol metadata cached on the model (e.g. during fit)
            metadata = {}
            for meta_name in ("gamma_dot", "gamma", "stress", "gamma0", "omega"):
                attr_name = f"_cached_{meta_name}"
                if hasattr(self, attr_name):
                    metadata[meta_name] = getattr(self, attr_name)
            rheo_data = RheoData(
                x=x_array, y=dummy_y, initial_test_mode=test_mode, metadata=metadata
            )

        smooth = kwargs.get("smooth", False)
        seed = kwargs.get("seed", 0)
        key = jax.random.PRNGKey(seed)

        # Full parameter dictionary, then the propagator for the *current* nu
        # scaled by the current mu.
        param_dict = self._get_param_dict()
        mu = self.parameters.get_value("mu")
        propagator_q = self._get_propagator_q_norm(param_dict["nu"]) * mu

        if test_mode == "flow_curve":
            return self._run_flow_curve(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "startup":
            return self._run_startup_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "relaxation":
            return self._run_relaxation_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "creep":
            return self._run_creep_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "oscillation":
            return self._run_oscillation_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        else:
            raise ValueError(f"Unknown test_mode: {test_mode}")

    def _run_startup_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Start-up shear: Stress(t) at constant rate.

        Extracts the shear component from the tensorial stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # dt from the time grid when possible (assumes uniform spacing)
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Constant shear rate from metadata
        gdot = metadata.get("gamma_dot", 0.1)

        # Scan for N-1 steps; the initial state supplies the first point
        n_steps = max(0, len(time) - 1)
        state = self._init_state(key)
        initial_stress = jnp.mean(state[0][2])

        def body(carrier, _):
            new_state = self._epm_step(
                carrier, propagator_q, gdot, dt, params, smooth
            )
            # Mean shear stress component (index 2)
            return new_state, jnp.mean(new_state[0][2])

        if n_steps > 0:
            _, stresses_scan = jax.lax.scan(body, state, None, length=n_steps)
            stresses = jnp.concatenate([jnp.array([initial_stress]), stresses_scan])
        else:
            stresses = jnp.array([initial_stress])

        return RheoData(x=time, y=stresses, initial_test_mode="startup")

    def _run_relaxation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Stress relaxation: G(t) after step strain.

        Extracts the shear component from the tensorial stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # dt from the time grid (assumes uniform spacing)
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Step strain magnitude from metadata
        strain_step = metadata.get("gamma", 0.1)

        state = self._init_state(key)
        stress, thresh, strain, k = state

        # Apply step strain (elastic load) - only to the shear component
        mu = params["mu"]
        stress = stress.at[2].set(stress[2] + mu * strain_step)
        state = (stress, thresh, strain + strain_step, k)

        # Initial G(0)
        g_0 = jnp.mean(stress[2]) / strain_step

        # Relax (gdot = 0) for N-1 steps
        n_steps = max(0, len(time) - 1)

        def body(carrier, _):
            new_state = self._epm_step(
                carrier, propagator_q, 0.0, dt, params, smooth
            )
            # G(t) = shear stress / gamma_0
            return new_state, jnp.mean(new_state[0][2]) / strain_step

        if n_steps > 0:
            _, moduli_scan = jax.lax.scan(body, state, None, length=n_steps)
            moduli = jnp.concatenate([jnp.array([g_0]), moduli_scan])
        else:
            moduli = jnp.array([g_0])

        return RheoData(x=time, y=moduli, initial_test_mode="relaxation")

    def _run_creep_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Creep: Strain(t) at constant stress using an adaptive P-controller.

        The imposed shear rate is adjusted each step so the mean shear stress
        tracks the target stress. Extracts the shear component from the
        tensorial stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")

        # dt from the time grid (assumes uniform spacing)
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Target stress: metadata is canonical; fall back to mean(y) only when
        # the caller passed y=constant (legacy pattern). See the matching
        # scalar fix in rheojax/models/epm/base.py::_run_creep.
        target_stress = data.metadata.get("stress") if data.metadata else None
        if target_stress is None:
            if data.y is not None and data.y.size > 0:
                y_mean = float(jnp.mean(data.y))
                target_stress = y_mean if abs(y_mean) > 1e-12 else 1.0
            else:
                target_stress = 1.0
        target_stress = float(target_stress)

        # Controller params
        Kp_base = 0.01
        alpha = 10.0

        state = self._init_state(key)
        # Augmented state: (EPM_State, current_gdot)
        aug_state = (state, 0.0)
        # Initial strain
        initial_strain = state[2]
        n_steps = max(0, len(time) - 1)

        def body(carrier, _):
            curr_epm, gdot = carrier
            # Mean shear stress component
            curr_stress = jnp.mean(curr_epm[0][2])

            # Gain scheduling: boost gain if error is large relative to target
            error = target_stress - curr_stress
            rel_error = jnp.abs(error) / (jnp.abs(target_stress) + 1e-6)
            Kp = Kp_base * (1.0 + alpha * rel_error)

            # P-control on the shear rate, never negative
            gdot_new = jnp.maximum(gdot + Kp * error, 0.0)

            new_epm = self._epm_step(
                curr_epm, propagator_q, gdot_new, dt, params, smooth
            )
            # Emit accumulated strain
            return (new_epm, gdot_new), new_epm[2]

        if n_steps > 0:
            _, strains_scan = jax.lax.scan(body, aug_state, None, length=n_steps)
            strains = jnp.concatenate([jnp.array([initial_strain]), strains_scan])
        else:
            strains = jnp.array([initial_strain])

        return RheoData(x=time, y=strains, initial_test_mode="creep")

    def _run_oscillation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """SAOS/LAOS: Stress(t) for sinusoidal strain γ(t) = γ0·sin(ωt).

        Extracts the shear component from the tensorial stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # dt from the time grid (assumes uniform spacing)
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        gamma0 = metadata.get("gamma0", 1.0)
        omega = metadata.get("omega", 1.0)

        state = self._init_state(key)
        initial_stress = jnp.mean(state[0][2])

        # Run for N-1 steps, driving each step with the rate at its start time
        n_steps = max(0, len(time) - 1)
        scan_time = time[:-1] if n_steps > 0 else jnp.array([])

        def body(carrier, t):
            # Time-varying shear rate at current time t
            gdot = gamma0 * omega * jnp.cos(omega * t)
            new_state = self._epm_step(
                carrier, propagator_q, gdot, dt, params, smooth
            )
            # Mean shear stress component
            return new_state, jnp.mean(new_state[0][2])

        if n_steps > 0:
            _, stresses_scan = jax.lax.scan(body, state, scan_time, length=n_steps)
            stresses = jnp.concatenate([jnp.array([initial_stress]), stresses_scan])
        else:
            stresses = jnp.array([initial_stress])

        return RheoData(x=time, y=stresses, initial_test_mode="oscillation")

    def _fit(self, X, y, **kwargs):
        """Fit tensorial-EPM parameters to shear-stress data.

        Currently supports **shear-only fitting** (y is 1D, matching the mean
        σ_xy(γ̇) for flow curves or σ_xy(t) for time-domain protocols). Under
        the hood this delegates to ``EPMBase._fit`` which runs NLSQ against
        the JAX-pure ``_model_*`` methods; the base-class
        ``_model_flow_curve`` and sister methods extract σ_xy from the
        tensorial (3, L, L) stress field via ``_mean_shear_stress`` so the
        model function returns a 1D array of shear stresses compatible with
        the user's 1D y.

        Joint fitting of [σ_xy, N₁] (combined mode with w_N1 weighting) is
        not yet supported — reshape to shear-only if your data includes N₁.

        Args:
            X: Shear rates (flow curve) or time array (time-domain protocols).
            y: 1D target data (mean σ_xy or modulus).
            **kwargs: Forwarded to ``EPMBase._fit``. See its docstring for
                supported options (``test_mode``, ``use_log_residuals``,
                ``max_iter``, ``ftol``, ``xtol``, protocol kwargs).

        Returns:
            self for method chaining.

        Raises:
            NotImplementedError: If ``y`` has more than one dimension.
        """
        # The base fit path uses self.model_function, which calls
        # _model_function_general for tensorial models (because
        # _is_scalar_epm() returns False); those generic _model_* methods
        # branch on stress.ndim to extract σ_xy from the (3, L, L) field.
        y_arr = jnp.asarray(y)
        if y_arr.ndim > 1:
            raise NotImplementedError(
                f"TensorialEPM currently supports shear-only fitting "
                f"(1D y of mean σ_xy). Got y with shape {y_arr.shape}. "
                f"For combined [σ_xy, N₁] fitting, provide just the σ_xy "
                f"row and use .predict() afterwards to inspect N₁."
            )
        return super()._fit(X, y_arr, **kwargs)