Source code for rheojax.models.epm.tensor

"""Tensorial Elasto-Plastic Model (EPM) implementation.

This module implements the full tensorial (3-component) stress formulation for EPM
simulations. It tracks the stress tensor [σ_xx, σ_yy, σ_xy] in 2D plane strain,
enabling prediction of normal stress differences (N₁, N₂), anisotropic flow behavior,
and kinematic hardening.

Key Features:
- Full tensorial stress state per lattice site
- Von Mises and Hill anisotropic yield criteria
- Normal stress difference predictions (N₁, N₂)
- Shear-only fitting of σ_xy (combined [σ_xy, N₁] fitting is not yet supported)
"""

from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.epm.base import EPMBase
from rheojax.utils.epm_kernels_tensorial import (
    make_tensorial_propagator_q,
    tensorial_epm_step,
)

jax, jnp = safe_import_jax()


@ModelRegistry.register(
    "tensorial_epm",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class TensorialEPM(EPMBase):
    """3-Component Tensorial Lattice EPM.

    A mesoscopic model for amorphous solids that explicitly tracks the full
    stress tensor, enabling predictions of normal stress differences and
    anisotropic flow.

    Physics:
        - Lattice of elastoplastic blocks with tensorial stress state.
        - Elastic loading (affine) for all stress components.
        - Von Mises or Hill yield criterion for anisotropic materials.
        - Component-wise plastic flow rule (Prandtl-Reuss).
        - Tensorial Eshelby propagator for stress redistribution.

    Parameters:
        mu (float): Shear modulus. Default 1.0.
        nu (float): Poisson's ratio for plane strain. Default 0.48 (avoid the
            0.5 incompressible singularity).
        tau_pl (float): Base plastic relaxation timescale (legacy). Default 1.0.
        tau_pl_shear (float): Plastic relaxation time for shear. Default 1.0.
        tau_pl_normal (float): Plastic relaxation time for normal stresses.
            Default 1.0.
        sigma_c_mean (float): Mean yield threshold. Default 1.0.
        sigma_c_std (float): Disorder strength (std dev of thresholds).
            Default 0.1.
        n_fluid (float): Flow-rule exponent forwarded to ``EPMBase``.
            Default 1.0.
        smoothing_width (float): Width for smooth yielding approx (inference
            only). Default 0.1. Not registered by this constructor;
            presumably provided by ``EPMBase`` (TODO confirm).
        w_N1 (float): Weight for N₁ in a combined fitting loss. Default 1.0.
            Registered for forward compatibility; ``_fit`` currently only
            accepts shear-only (1D) data.
        hill_H (float): Hill anisotropy parameter H. Default 0.5.
        hill_N (float): Hill anisotropy parameter N. Default 1.5.

    Configuration:
        L (int): Lattice size (LxL). Default 64.
        dt (float): Time step. Default 0.01.
        yield_criterion (str): "von_mises" or "hill". Default "von_mises".
        fluidity_form (str): "linear", "power" or "overstress".
            Default "overstress".
        n_bayesian_steps (int): Number of time steps for Bayesian inference.
            Default 200.
    """

    def __init__(
        self,
        L: int = 64,
        dt: float = 0.01,
        mu: float = 1.0,
        nu: float = 0.48,  # Avoid nu=0.5 (incompressible singularity in plane strain)
        tau_pl: float = 1.0,
        tau_pl_shear: float = 1.0,
        tau_pl_normal: float = 1.0,
        sigma_c_mean: float = 1.0,
        sigma_c_std: float = 0.1,
        n_fluid: float = 1.0,
        yield_criterion: str = "von_mises",
        n_bayesian_steps: int = 200,
        fluidity_form: str = "overstress",
    ):
        """Initialize the Tensorial EPM.

        Args:
            L: Lattice size (LxL grid).
            dt: Time step for integration.
            mu: Shear modulus.
            nu: Poisson's ratio for plane strain constraint.
            tau_pl: Base plastic relaxation time (for compatibility).
            tau_pl_shear: Plastic relaxation time for shear components.
            tau_pl_normal: Plastic relaxation time for normal stress components.
            sigma_c_mean: Mean yield threshold.
            sigma_c_std: Standard deviation of yield thresholds (disorder).
            n_fluid: Flow-rule exponent forwarded to ``EPMBase``.
            yield_criterion: Yield criterion name ("von_mises" or "hill").
            n_bayesian_steps: Number of time steps for Bayesian inference.
                Default 200.
            fluidity_form: Plastic flow law, one of "linear", "power" or
                "overstress". Default "overstress".

        Raises:
            ValueError: If ``fluidity_form`` or ``yield_criterion`` is not
                one of the supported names.
        """
        # Validate static configuration before registering any parameters.
        # TensorialEPM uses the Prandtl-Reuss flow rule via its own kernel
        # (``rheojax.utils.epm_kernels_tensorial``). The kernel now supports
        # all three fluidity forms — see `compute_plastic_strain_rate` for
        # the mathematical definitions. "overstress" is the default and
        # recommended choice for yield-stress fluids (emulsions, gels, foams,
        # pastes); it produces a full Herschel-Bulkley flow curve with the
        # von-Mises pure-shear plateau at sigma_c_mean / sqrt(3).
        if fluidity_form not in ("linear", "power", "overstress"):
            raise ValueError(
                f"fluidity_form must be 'linear', 'power', or 'overstress'; "
                f"got {fluidity_form!r}."
            )
        if yield_criterion not in ["von_mises", "hill"]:
            raise ValueError(
                f"Unknown yield criterion: {yield_criterion}. "
                "Must be 'von_mises' or 'hill'."
            )

        # Initialize base class with common parameters.
        super().__init__(
            L=L,
            dt=dt,
            mu=mu,
            tau_pl=tau_pl,
            sigma_c_mean=sigma_c_mean,
            sigma_c_std=sigma_c_std,
            n_fluid=n_fluid,
            n_bayesian_steps=n_bayesian_steps,
            fluidity_form=fluidity_form,
        )

        # Tensorial-specific parameters.
        self.parameters.add(
            "nu",
            nu,
            bounds=(0.3, 0.5),
            units="",
            description="Poisson's ratio for plane strain",
        )
        self.parameters.add(
            "tau_pl_shear",
            tau_pl_shear,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for shear",
        )
        self.parameters.add(
            "tau_pl_normal",
            tau_pl_normal,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for normal stresses",
        )
        self.parameters.add(
            "w_N1",
            1.0,
            bounds=(0.1, 10.0),
            units="",
            description="Weight for N₁ in combined fitting loss",
        )
        self.parameters.add(
            "hill_H",
            0.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter H",
        )
        self.parameters.add(
            "hill_N",
            1.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter N",
        )

        # Yield criterion (static configuration, validated above).
        self.yield_criterion = yield_criterion

        # Precompute tensorial propagator (cached).
        # Using mu=1.0 as normalization; scaled by the actual mu at execution.
        # The nu it was built with is remembered so the cache can be rebuilt
        # if the "nu" parameter is later changed (see _get_propagator_q_norm).
        self._propagator_q_norm = make_tensorial_propagator_q(L, nu=nu, mu=1.0)
        self._propagator_nu = float(nu)

    def _init_stress(self, key: jax.Array) -> jax.Array:
        """Initialize tensorial stress field.

        Args:
            key: PRNG key (unused for zero initialization).

        Returns:
            Zero-initialized stress tensor of shape (3, L, L) for
            [σ_xx, σ_yy, σ_xy].
        """
        # Start relaxed (zero stress)
        return jnp.zeros((3, self.L, self.L))

    def _get_propagator_q_norm(self) -> jax.Array:
        """Return the mu-normalised propagator for the current ``nu`` value.

        The propagator is first built in ``__init__`` from the constructor's
        ``nu``. Because ``nu`` is a registered, bounded parameter, it can be
        changed later. The propagator is rebuilt whenever the current
        parameter value differs from the one the cache was built with.

        Returns:
            Propagator array built with ``mu=1.0`` (caller scales by ``mu``).
        """
        nu = self.parameters.get_value("nu")
        if nu is None:
            return self._propagator_q_norm
        try:
            nu_value = float(nu)
        except (TypeError, ValueError):
            # Non-concrete value (e.g. a JAX tracer): cannot compare, so keep
            # the cached propagator rather than failing.
            return self._propagator_q_norm
        if nu_value != getattr(self, "_propagator_nu", None):
            self._propagator_q_norm = make_tensorial_propagator_q(
                self.L, nu=nu_value, mu=1.0
            )
            self._propagator_nu = nu_value
        return self._propagator_q_norm

    def _is_scalar_epm(self) -> bool:
        """Check if this is a scalar (not tensorial) EPM.

        Returns False for TensorialEPM (tensorial stress field). This causes
        model_function to use general methods instead of scalar JIT kernels.
        """
        return False

    def _get_param_dict(self) -> dict[str, float]:
        """Extract parameters as dictionary for kernel calls.

        Extends the base class method with the tensorial parameters.

        Returns:
            Dictionary with all EPM parameters including tensorial ones.

        Raises:
            ValueError: If any tensorial parameter has no value set.
        """
        base_params = super()._get_param_dict()

        tensorial_params = {}
        for name in ("nu", "tau_pl_shear", "tau_pl_normal", "hill_H", "hill_N"):
            value = self.parameters.get_value(name)
            if value is None:
                raise ValueError(f"Parameter '{name}' must be set before use")
            tensorial_params[name] = value

        return {**base_params, **tensorial_params}

    def _epm_step(
        self,
        state: tuple[jax.Array, jax.Array, float, jax.Array],
        propagator_q: jax.Array,
        shear_rate: float,
        dt: float,
        params: dict,
        smooth: bool,
    ) -> tuple[jax.Array, jax.Array, float, jax.Array]:
        """Perform one tensorial EPM time step.

        Delegates to tensorial_epm_step kernel from epm_kernels_tensorial module.

        Args:
            state: Current state (stress, thresholds, strain, key).
            propagator_q: Precomputed tensorial propagator.
            shear_rate: Imposed macroscopic shear rate.
            dt: Time step size.
            params: Model parameters dictionary.
            smooth: Use smooth yielding (tanh) or hard threshold (step).

        Returns:
            Updated state tuple (new_stress, thresholds, new_strain, key).
        """
        stress, thresholds, strain, key = state

        # Forward the model's fluidity_form so that NLSQ fitting, NUTS
        # inference, and forward .predict() all use the same constitutive law.
        new_stress = tensorial_epm_step(
            stress=stress,
            thresholds=thresholds,
            strain_rate=shear_rate,
            dt=dt,
            propagator=propagator_q,
            params=params,
            smooth=smooth,
            yield_criterion=self.yield_criterion,
            fluidity_form=self.fluidity_form,
        )

        # Update accumulated strain
        new_strain = strain + shear_rate * dt

        return (new_stress, thresholds, new_strain, key)

    def get_shear_stress(self, stress: jax.Array) -> jax.Array:
        """Extract shear stress component from stress tensor.

        Args:
            stress: Stress tensor of shape (3, L, L) or (..., 3, L, L).

        Returns:
            Shear stress σ_xy, shape (L, L) or (..., L, L).
        """
        # Component ordering: [σ_xx, σ_yy, σ_xy]
        if stress.ndim == 3:
            return stress[2]
        else:
            return stress[..., 2, :, :]

    def get_normal_stress_differences(
        self, stress: jax.Array, nu: float | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Compute normal stress differences from stress tensor.

        For plane strain: σ_zz = ν(σ_xx + σ_yy)

        Normal stress differences:
            - N₁ = σ_xx - σ_yy
            - N₂ = σ_yy - σ_zz

        Args:
            stress: Stress tensor of shape (3, L, L) or, for batched input,
                (..., 3, L, L) with components ordered [σ_xx, σ_yy, σ_xy].
            nu: Poisson's ratio. If None, uses parameter value.

        Returns:
            Tuple (N₁, N₂), each of shape (L, L) (or (..., L, L) if batched).
        """
        if nu is None:
            nu = self.parameters.get_value("nu")

        # Batched input keeps the component axis third from the end, which
        # mirrors get_shear_stress; unbatched input uses the leading axis.
        if stress.ndim > 3:
            sigma_xx = stress[..., 0, :, :]
            sigma_yy = stress[..., 1, :, :]
        else:
            sigma_xx = stress[0]
            sigma_yy = stress[1]

        # Plane strain constraint
        sigma_zz = nu * (sigma_xx + sigma_yy)

        N1 = sigma_xx - sigma_yy
        N2 = sigma_yy - sigma_zz

        return N1, N2

    def predict_normal_stresses(
        self, data: RheoData, **kwargs
    ) -> tuple[jax.Array, jax.Array]:
        """Convenience method to predict normal stress differences.

        Intended to run the simulation and extract spatially averaged N₁ and
        N₂ over time.

        Args:
            data: RheoData with protocol specification.
            **kwargs: Additional arguments passed to predict().

        Returns:
            Tuple (N₁_array, N₂_array) with time-averaged values.

        Raises:
            NotImplementedError: Always; not yet implemented. For flow
                curves, predict() already stores N₁ in
                ``result.metadata['N1']``.
        """
        raise NotImplementedError(
            "predict_normal_stresses() not yet implemented. "
            "Use predict() with test_mode='flow_curve'; the result holds "
            "σ_xy in y and N₁ in metadata['N1']."
        )

    # Override flow_curve to return shear stress with N₁ in metadata
    def _run_flow_curve(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Steady state flow curve: Stress vs Shear Rate.

        For tensorial EPM, returns shear stress σ_xy with N₁ stored in metadata.

        Args:
            data: RheoData with x=shear_rates.
            key: PRNG key.
            propagator_q: Precomputed propagator.
            params: Model parameters.
            smooth: Use smooth yielding.

        Returns:
            RheoData with x=shear_rates, y=sigma_xy.
            metadata contains 'N1' with first normal stress differences.
        """
        shear_rates = data.x
        sigma_c_mean = params["sigma_c_mean"]
        tau_pl_shear = params.get("tau_pl_shear", params.get("tau_pl", 1.0))
        n_fluid = params.get("n_fluid", 1.0)

        def scan_fn(gdot):
            # Run simulation for sufficient steps to reach steady state
            n_steps = 1000
            state = self._init_state(key)
            stress0, thresholds, strain, k = state

            # Warm-start sigma_xy at the analytical overstress steady state
            # for pure shear with von Mises yielding.
            #
            # The tensorial stress update is dσ_xy/dt = μγ̇ − 2μ·ε̇^p_xy (see the
            # Budrikis & Zapperi 2013 reference at the top of
            # epm_kernels_tensorial.py), so at steady state ε̇^p_xy = γ̇/2 in the
            # tensor convention. Solving the overstress Prandtl-Reuss rule for
            # pure shear gives:
            #
            #   σ_xy = σ_c_mean/√3
            #          + (1/√3) · (√3 / 2)^(1/n_fluid)
            #                   · σ_c_mean^((n_fluid − 1)/n_fluid)
            #                   · (|γ̇|·τ_pl_shear)^(1/n_fluid)
            #
            # Special cases:
            #   n_fluid = 1 → σ_xy = σ_c_mean/√3 + |γ̇|·τ_pl_shear / 2
            #   n_fluid = 2 → σ_xy = σ_c_mean/√3
            #                       + √( σ_c_mean · |γ̇| · τ_pl_shear / (2·√3) )
            #
            # Starting near this target avoids long transient loading times at
            # low shear rates and destabilising FFT redistribution at high
            # rates. Normal components (indices 0, 1) are left at zero — they
            # develop self-consistently from the Prandtl-Reuss flow.
            sqrt3 = jnp.sqrt(3.0)
            gdot_abs = jnp.abs(gdot)
            # Guard against pathological parameter combinations that NLSQ can
            # transiently probe during fitting (very small sigma_c_mean or
            # n_fluid, extreme shear rates). The clamps below keep the
            # analytical warm-start formula numerically well-defined without
            # affecting its value at reasonable parameters.
            scm_safe = jnp.maximum(sigma_c_mean, 1e-6)
            n_safe = jnp.maximum(n_fluid, 1e-3)
            tau_safe = jnp.maximum(tau_pl_shear, 1e-6)
            inv_n = 1.0 / n_safe
            excess_base = jnp.maximum(gdot_abs * tau_safe, 0.0)
            warm_excess = (
                (1.0 / sqrt3)
                * (sqrt3 / 2.0) ** inv_n
                * scm_safe ** ((n_safe - 1.0) * inv_n)
                * excess_base**inv_n
            )
            # Final clamp: if any component is NaN/inf, fall back to the
            # plateau value scm_safe / sqrt(3) so the simulation still runs.
            sigma_xy_warm_raw = jnp.sign(gdot) * (scm_safe / sqrt3 + warm_excess)
            sigma_xy_warm = jnp.where(
                jnp.isfinite(sigma_xy_warm_raw),
                sigma_xy_warm_raw,
                jnp.sign(gdot) * scm_safe / sqrt3,
            )
            stress0 = stress0.at[2].add(sigma_xy_warm)
            state = (stress0, thresholds, strain, k)

            def body(carrier, _):
                new_state = self._epm_step(
                    carrier, propagator_q, gdot, self.dt, params, smooth
                )
                stress_tensor = new_state[0]  # Shape (3, L, L)

                # Extract shear stress and N₁
                sigma_xy_mean = jnp.mean(stress_tensor[2])
                N1, _ = self.get_normal_stress_differences(
                    stress_tensor, params["nu"]
                )
                N1_mean = jnp.mean(N1)

                return new_state, jnp.array([sigma_xy_mean, N1_mean])

            _, history = jax.lax.scan(body, state, None, length=n_steps)

            # Average last 50% for steady state; history has shape (n_steps, 2)
            steady_values = jnp.mean(history[n_steps // 2 :], axis=0)
            return steady_values  # [σ_xy, N₁]

        # Vectorize over shear rates
        stresses = jax.vmap(scan_fn)(shear_rates)  # Shape: (n_rates, 2)

        # Extract components with NaN safety net. NLSQ fitting can transiently
        # probe parameter combinations that make the explicit-Euler kernel
        # unstable (e.g., very small sigma_c_mean with large n_fluid at high
        # gdot), producing NaN in the steady-state average. Replace those
        # entries with the analytical plateau sigma_c_mean/sqrt(3) so the
        # validation downstream does not explode — the optimiser will
        # naturally move away from NaN regions in the next iteration.
        sigma_xy_raw = stresses[:, 0]
        N1_raw = stresses[:, 1]
        scm_safe = jnp.maximum(params.get("sigma_c_mean", 1.0), 1e-6)
        plateau_fallback = scm_safe / jnp.sqrt(3.0)
        sigma_xy = jnp.where(jnp.isfinite(sigma_xy_raw), sigma_xy_raw, plateau_fallback)
        N1 = jnp.where(jnp.isfinite(N1_raw), N1_raw, 0.0)

        # Store N₁ in metadata
        result_metadata = data.metadata.copy() if data.metadata else {}
        result_metadata["N1"] = N1

        return RheoData(
            x=shear_rates,
            y=sigma_xy,
            initial_test_mode="flow_curve",
            metadata=result_metadata,
        )

    def _predict(self, X, **kwargs) -> RheoData:
        """Simulate the model for the given protocol.

        Args:
            X: Input data - can be RheoData or numpy/JAX array.
            kwargs:
                test_mode (str): 'flow_curve', 'startup', 'relaxation',
                    'creep', 'oscillation'.
                smooth (bool): Use smooth yielding (default False for
                    simulation, True for fitting).
                seed (int): Random seed (default 0).

        Returns:
            RheoData with simulation results.
                - flow_curve: y is σ_xy array, metadata['N1'] contains N₁ values
                - Other protocols: y has shape (n_points,) with σ_xy only

        Raises:
            ValueError: If test_mode is not one of the supported protocols.
        """
        # Handle both RheoData and raw array input
        if isinstance(X, RheoData):
            rheo_data = X
            test_mode = kwargs.get("test_mode", rheo_data.test_mode)
        else:
            # Raw array input - wrap in RheoData
            test_mode = kwargs.get("test_mode")
            if test_mode is None:
                test_mode = getattr(self, "_test_mode", "flow_curve")

            x_array = jnp.asarray(X, dtype=jnp.float64)
            # Create dummy y for RheoData constructor
            dummy_y = jnp.zeros_like(x_array)

            # Copy cached protocol metadata, if any was stored earlier
            metadata = {}
            if hasattr(self, "_cached_gamma_dot"):
                metadata["gamma_dot"] = self._cached_gamma_dot
            if hasattr(self, "_cached_gamma"):
                metadata["gamma"] = self._cached_gamma
            if hasattr(self, "_cached_stress"):
                metadata["stress"] = self._cached_stress
            if hasattr(self, "_cached_gamma0"):
                metadata["gamma0"] = self._cached_gamma0
            if hasattr(self, "_cached_omega"):
                metadata["omega"] = self._cached_omega

            rheo_data = RheoData(
                x=x_array, y=dummy_y, initial_test_mode=test_mode, metadata=metadata
            )

        smooth = kwargs.get("smooth", False)
        seed = kwargs.get("seed", 0)
        key = jax.random.PRNGKey(seed)

        # Scale the (nu-consistent) normalised propagator by the current mu
        mu = self.parameters.get_value("mu")
        propagator_q = self._get_propagator_q_norm() * mu

        # Get full parameter dictionary
        param_dict = self._get_param_dict()

        if test_mode == "flow_curve":
            return self._run_flow_curve(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "startup":
            return self._run_startup_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "relaxation":
            return self._run_relaxation_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "creep":
            return self._run_creep_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        elif test_mode == "oscillation":
            return self._run_oscillation_tensorial(
                rheo_data, key, propagator_q, param_dict, smooth
            )
        else:
            raise ValueError(f"Unknown test_mode: {test_mode}")

    def _run_startup_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Start-up shear: Stress(t) at constant rate.

        Extracts shear component from tensorial stress. The shear rate is
        read from ``metadata['gamma_dot']`` (default 0.1); missing or None
        metadata is treated as empty.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # Calculate dt from data if possible
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Constant shear rate from metadata
        gdot = metadata.get("gamma_dot", 0.1)

        # Scan for N-1 steps; the initial state supplies the first point
        n_steps = max(0, len(time) - 1)

        state = self._init_state(key)

        def body(carrier, _):
            new_state = self._epm_step(carrier, propagator_q, gdot, dt, params, smooth)
            # Extract shear stress component (index 2)
            return new_state, jnp.mean(new_state[0][2])

        initial_stress = jnp.mean(state[0][2])
        if n_steps > 0:
            _, stresses_scan = jax.lax.scan(body, state, None, length=n_steps)
            stresses = jnp.concatenate([jnp.array([initial_stress]), stresses_scan])
        else:
            stresses = jnp.array([initial_stress])

        return RheoData(x=time, y=stresses, initial_test_mode="startup")

    def _run_relaxation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Stress relaxation: G(t) after step strain.

        Extracts shear component from tensorial stress. The step strain is
        read from ``metadata['gamma']`` (default 0.1); missing or None
        metadata is treated as empty. Note: G(t) is stress divided by the
        step strain, so a zero step strain yields non-finite values.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # Calculate dt from data
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Step strain magnitude from metadata
        strain_step = metadata.get("gamma", 0.1)

        state = self._init_state(key)
        stress, thresh, strain, k = state

        # Apply Step Strain (Elastic Load) - only to shear component
        mu = params["mu"]
        stress = stress.at[2].set(stress[2] + mu * strain_step)
        state = (stress, thresh, strain + strain_step, k)

        # Initial G(0)
        g_0 = jnp.mean(stress[2]) / strain_step

        # Relax (gdot = 0) for N-1 steps
        n_steps = max(0, len(time) - 1)

        def body(carrier, _):
            new_state = self._epm_step(carrier, propagator_q, 0.0, dt, params, smooth)
            # Return G(t) = Shear Stress / gamma_0
            return new_state, jnp.mean(new_state[0][2]) / strain_step

        if n_steps > 0:
            _, moduli_scan = jax.lax.scan(body, state, None, length=n_steps)
            moduli = jnp.concatenate([jnp.array([g_0]), moduli_scan])
        else:
            moduli = jnp.array([g_0])

        return RheoData(x=time, y=moduli, initial_test_mode="relaxation")

    def _run_creep_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Creep: Strain(t) at constant stress using Adaptive P-Controller.

        Extracts shear component from tensorial stress. The imposed shear
        rate is adjusted each step by a proportional controller (clamped at
        zero) so the mean σ_xy tracks the target stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")

        # Calculate dt from data
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Target stress: metadata is canonical; fall back to mean(y) only when
        # the caller passed y=constant (legacy pattern). See the matching
        # scalar fix in rheojax/models/epm/base.py::_run_creep.
        target_stress = data.metadata.get("stress") if data.metadata else None
        if target_stress is None:
            if data.y is not None and data.y.size > 0:
                y_mean = float(jnp.mean(data.y))
                target_stress = y_mean if abs(y_mean) > 1e-12 else 1.0
            else:
                target_stress = 1.0
        target_stress = float(target_stress)

        # Controller Params
        Kp_base = 0.01
        alpha = 10.0

        state = self._init_state(key)
        # Augmented state: (EPM_State, current_gdot)
        aug_state = (state, 0.0)

        # Initial strain (0.0)
        initial_strain = state[2]

        n_steps = max(0, len(time) - 1)

        def body(carrier, _):
            curr_epm, gdot = carrier
            stress_grid = curr_epm[0]

            # Extract shear stress component
            curr_stress = jnp.mean(stress_grid[2])

            # Adaptive Control
            error = target_stress - curr_stress

            # Gain scheduling: Boost gain if error is large relative to target
            rel_error = jnp.abs(error) / (jnp.abs(target_stress) + 1e-6)
            Kp = Kp_base * (1.0 + alpha * rel_error)

            # Update shear rate (P-control on rate)
            gdot_new = gdot + Kp * error

            # Prevent negative shear rate
            gdot_new = jnp.maximum(gdot_new, 0.0)

            # Step EPM
            new_epm = self._epm_step(
                curr_epm, propagator_q, gdot_new, dt, params, smooth
            )

            # Return Strain
            return (new_epm, gdot_new), new_epm[2]

        if n_steps > 0:
            _, strains_scan = jax.lax.scan(body, aug_state, None, length=n_steps)
            strains = jnp.concatenate([jnp.array([initial_strain]), strains_scan])
        else:
            strains = jnp.array([initial_strain])

        return RheoData(x=time, y=strains, initial_test_mode="creep")

    def _run_oscillation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """SAOS/LAOS: Stress(t) for sinusoidal strain.

        Extracts shear component from tensorial stress. Amplitude and angular
        frequency are read from ``metadata['gamma0']`` and
        ``metadata['omega']`` (default 1.0 each); missing or None metadata is
        treated as empty.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        metadata = data.metadata or {}

        # Calculate dt from data
        dt = self.dt
        if len(time) > 1:
            dt = float(time[1] - time[0])

        # Params
        gamma0 = metadata.get("gamma0", 1.0)
        omega = metadata.get("omega", 1.0)

        state = self._init_state(key)

        # Initial stress
        initial_stress = jnp.mean(state[0][2])

        # Run for N-1 steps, evaluating the rate at the start of each step
        n_steps = max(0, len(time) - 1)
        scan_time = time[:-1] if n_steps > 0 else jnp.array([])

        def body(carrier, t):
            # Time varying shear rate at current time t
            gdot = gamma0 * omega * jnp.cos(omega * t)
            new_state = self._epm_step(carrier, propagator_q, gdot, dt, params, smooth)
            # Extract shear stress component
            return new_state, jnp.mean(new_state[0][2])

        if n_steps > 0:
            _, stresses_scan = jax.lax.scan(body, state, scan_time, length=n_steps)
            stresses = jnp.concatenate([jnp.array([initial_stress]), stresses_scan])
        else:
            stresses = jnp.array([initial_stress])

        return RheoData(x=time, y=stresses, initial_test_mode="oscillation")

    def _fit(self, X, y, **kwargs):
        """Fit tensorial-EPM parameters to shear-stress data.

        Currently supports **shear-only fitting** (y is 1D, matching the mean
        σ_xy(γ̇) for flow curves or σ_xy(t) for time-domain protocols). Under
        the hood this delegates to ``EPMBase._fit`` which runs NLSQ against
        the JAX-pure ``_model_*`` methods; the base-class ``_model_flow_curve``
        and sister methods extract σ_xy from the tensorial (3, L, L) stress
        field via ``_mean_shear_stress`` so the model function returns a 1D
        array of shear stresses compatible with the user's 1D y.

        Joint fitting of [σ_xy, N₁] (combined mode with w_N1 weighting) is
        not yet supported — reshape to shear-only if your data includes N₁.

        Args:
            X: Shear rates (flow curve) or time array (time-domain protocols).
            y: 1D target data (mean σ_xy or modulus).
            **kwargs: Forwarded to ``EPMBase._fit``. See its docstring for
                supported options (``test_mode``, ``use_log_residuals``,
                ``max_iter``, ``ftol``, ``xtol``, protocol kwargs).

        Returns:
            self for method chaining.

        Raises:
            NotImplementedError: If y is not 1D.
        """
        # Delegate to EPMBase._fit. The base fit path uses self.model_function
        # which calls _model_function_general for tensorial (because
        # _is_scalar_epm() returns False), which in turn calls the generic
        # _model_flow_curve / _model_startup / ... in base.py. Those methods
        # branch on stress.ndim to extract σ_xy correctly from the
        # tensorial (3, L, L) stress field.
        y_arr = jnp.asarray(y)
        if y_arr.ndim > 1:
            raise NotImplementedError(
                f"TensorialEPM currently supports shear-only fitting "
                f"(1D y of mean σ_xy). Got y with shape {y_arr.shape}. "
                f"For combined [σ_xy, N₁] fitting, provide just the σ_xy "
                f"row and use .predict() afterwards to inspect N₁."
            )
        return super()._fit(X, y_arr, **kwargs)