# Source code for rheojax.models.epm.tensor

"""Tensorial Elasto-Plastic Model (EPM) implementation.

This module implements the full tensorial (3-component) stress formulation for EPM
simulations. It tracks the stress tensor [σ_xx, σ_yy, σ_xy] in 2D plane strain,
enabling prediction of normal stress differences (N₁, N₂), anisotropic flow behavior,
and kinematic hardening.

Key Features:
- Full tensorial stress state per lattice site
- Von Mises and Hill anisotropic yield criteria
- Normal stress difference predictions (N₁, N₂)
- Flexible fitting: shear-only or combined [σ_xy, N₁]
"""

from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.epm.base import EPMBase
from rheojax.utils.epm_kernels_tensorial import (
    make_tensorial_propagator_q,
    tensorial_epm_step,
)

jax, jnp = safe_import_jax()


@ModelRegistry.register(
    "tensorial_epm",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class TensorialEPM(EPMBase):
    """3-Component Tensorial Lattice EPM.

    A mesoscopic model for amorphous solids that explicitly tracks the full
    stress tensor, enabling predictions of normal stress differences and
    anisotropic flow.

    Physics:
        - Lattice of elastoplastic blocks with tensorial stress state.
        - Elastic loading (affine) for all stress components.
        - Von Mises or Hill yield criterion for anisotropic materials.
        - Component-wise plastic flow rule (Prandtl-Reuss).
        - Tensorial Eshelby propagator for stress redistribution.

    Parameters:
        mu (float): Shear modulus. Default 1.0.
        nu (float): Poisson's ratio for plane strain. Default 0.48
            (avoid 0.5 singularity).
        tau_pl (float): Base plastic relaxation timescale (legacy). Default 1.0.
        tau_pl_shear (float): Plastic relaxation time for shear. Default 1.0.
        tau_pl_normal (float): Plastic relaxation time for normal stresses.
            Default 1.0.
        sigma_c_mean (float): Mean yield threshold. Default 1.0.
        sigma_c_std (float): Disorder strength (std dev of thresholds).
            Default 0.1.
        smoothing_width (float): Width for smooth yielding approx
            (inference only). Default 0.1.
        w_N1 (float): Weight for N₁ in combined fitting loss. Default 1.0.
        hill_H (float): Hill anisotropy parameter H. Default 0.5.
        hill_N (float): Hill anisotropy parameter N. Default 1.5.

    Configuration:
        L (int): Lattice size (LxL). Default 64.
        dt (float): Time step. Default 0.01.
        yield_criterion (str): "von_mises" or "hill". Default "von_mises".
    """

    def __init__(
        self,
        L: int = 64,
        dt: float = 0.01,
        mu: float = 1.0,
        nu: float = 0.48,  # Avoid nu=0.5 (incompressible singularity in plane strain)
        tau_pl: float = 1.0,
        tau_pl_shear: float = 1.0,
        tau_pl_normal: float = 1.0,
        sigma_c_mean: float = 1.0,
        sigma_c_std: float = 0.1,
        yield_criterion: str = "von_mises",
        n_bayesian_steps: int = 200,
    ):
        """Initialize the Tensorial EPM.

        Args:
            L: Lattice size (LxL grid).
            dt: Time step for integration.
            mu: Shear modulus.
            nu: Poisson's ratio for plane strain constraint.
            tau_pl: Base plastic relaxation time (for compatibility).
            tau_pl_shear: Plastic relaxation time for shear components.
            tau_pl_normal: Plastic relaxation time for normal stress components.
            sigma_c_mean: Mean yield threshold.
            sigma_c_std: Standard deviation of yield thresholds (disorder).
            yield_criterion: Yield criterion name ("von_mises" or "hill").
            n_bayesian_steps: Number of time steps for Bayesian inference.
                Default 200.

        Raises:
            ValueError: If ``yield_criterion`` is not "von_mises" or "hill".
        """
        # Initialize base class with common parameters
        super().__init__(
            L=L,
            dt=dt,
            mu=mu,
            tau_pl=tau_pl,
            sigma_c_mean=sigma_c_mean,
            sigma_c_std=sigma_c_std,
            n_bayesian_steps=n_bayesian_steps,
        )

        # Add tensorial-specific parameters
        self.parameters.add(
            "nu",
            nu,
            bounds=(0.3, 0.5),
            units="",
            description="Poisson's ratio for plane strain",
        )
        self.parameters.add(
            "tau_pl_shear",
            tau_pl_shear,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for shear",
        )
        self.parameters.add(
            "tau_pl_normal",
            tau_pl_normal,
            bounds=(0.01, 100.0),
            units="s",
            description="Plastic relaxation time for normal stresses",
        )
        self.parameters.add(
            "w_N1",
            1.0,
            bounds=(0.1, 10.0),
            units="",
            description="Weight for N₁ in combined fitting loss",
        )
        self.parameters.add(
            "hill_H",
            0.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter H",
        )
        self.parameters.add(
            "hill_N",
            1.5,
            bounds=(0.1, 5.0),
            units="",
            description="Hill anisotropy parameter N",
        )

        # Yield criterion (static configuration)
        if yield_criterion not in ("von_mises", "hill"):
            raise ValueError(
                f"Unknown yield criterion: {yield_criterion}. "
                "Must be 'von_mises' or 'hill'."
            )
        self.yield_criterion = yield_criterion

        # Precompute tensorial propagator (cached).
        # Using mu=1.0 as normalization, will scale by actual mu during execution.
        # The nu used for the cache is remembered so that a later change of the
        # 'nu' parameter can trigger a rebuild (see _current_propagator_q_norm).
        self._propagator_q_norm = make_tensorial_propagator_q(L, nu=nu, mu=1.0)
        self._propagator_nu = nu

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _current_propagator_q_norm(self) -> jax.Array:
        """Return the mu-normalized propagator matching the current ``nu``.

        The propagator is built in ``__init__`` from the constructor ``nu``.
        Because ``nu`` is also a registered parameter, its value may change
        afterwards; in that case the cached propagator is rebuilt so that it
        stays consistent with the ``nu`` passed to the kernels.
        """
        nu = self.parameters.get_value("nu")
        if nu is not None and nu != getattr(self, "_propagator_nu", None):
            self._propagator_q_norm = make_tensorial_propagator_q(
                self.L, nu=nu, mu=1.0
            )
            self._propagator_nu = nu
        return self._propagator_q_norm

    def _resolve_time_grid(self, data: RheoData) -> tuple[jax.Array, float, int]:
        """Extract the time array, step size and number of scan steps.

        The step size is taken from the first two time points when available
        (i.e. the grid is treated as uniform); otherwise ``self.dt`` is used.

        Returns:
            Tuple ``(time, dt, n_steps)`` with ``n_steps = max(0, len(time) - 1)``.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time = data.x
        if time is None:
            raise ValueError("data.x (time array) must not be None")
        dt = float(time[1] - time[0]) if len(time) > 1 else self.dt
        n_steps = max(0, len(time) - 1)
        return time, dt, n_steps

    @staticmethod
    def _scan_series(body, carry, initial_value, n_steps: int, xs=None) -> jax.Array:
        """Run ``body`` for ``n_steps`` and prepend ``initial_value``.

        Returns a 1-element array holding only ``initial_value`` when
        ``n_steps`` is not positive (no scan is executed in that case).
        """
        head = jnp.array([initial_value])
        if n_steps <= 0:
            return head
        _, tail = jax.lax.scan(body, carry, xs, length=n_steps)
        return jnp.concatenate([head, tail])

    def _init_stress(self, key: jax.Array) -> jax.Array:
        """Initialize tensorial stress field.

        Args:
            key: PRNG key (unused for zero initialization).

        Returns:
            Zero-initialized stress tensor of shape (3, L, L) for
            [σ_xx, σ_yy, σ_xy].
        """
        # Start relaxed (zero stress)
        return jnp.zeros((3, self.L, self.L))

    def _is_scalar_epm(self) -> bool:
        """Check if this is a scalar (not tensorial) EPM.

        Returns False for TensorialEPM (tensorial stress field). This causes
        model_function to use general methods instead of scalar JIT kernels.
        """
        return False

    def _get_param_dict(self) -> dict[str, float]:
        """Extract parameters as dictionary for kernel calls.

        Extends base class method to include tensorial parameters.

        Returns:
            Dictionary with all EPM parameters including tensorial ones.

        Raises:
            ValueError: If any tensorial parameter has no value.
        """
        base_params = super()._get_param_dict()

        tensorial_params = {}
        for name in ("nu", "tau_pl_shear", "tau_pl_normal", "hill_H", "hill_N"):
            value = self.parameters.get_value(name)
            if value is None:
                raise ValueError(f"Parameter '{name}' must be set before use")
            tensorial_params[name] = value

        return {**base_params, **tensorial_params}

    def _epm_step(
        self,
        state: tuple[jax.Array, jax.Array, float, jax.Array],
        propagator_q: jax.Array,
        shear_rate: float,
        dt: float,
        params: dict,
        smooth: bool,
    ) -> tuple[jax.Array, jax.Array, float, jax.Array]:
        """Perform one tensorial EPM time step.

        Delegates to tensorial_epm_step kernel from epm_kernels_tensorial module.

        Args:
            state: Current state (stress, thresholds, strain, key).
            propagator_q: Precomputed tensorial propagator.
            shear_rate: Imposed macroscopic shear rate.
            dt: Time step size.
            params: Model parameters dictionary.
            smooth: Use smooth yielding (tanh) or hard threshold (step).

        Returns:
            Updated state tuple (new_stress, thresholds, new_strain, key).
        """
        stress, thresholds, strain, key = state

        # Call tensorial kernel
        new_stress = tensorial_epm_step(
            stress=stress,
            thresholds=thresholds,
            strain_rate=shear_rate,
            dt=dt,
            propagator=propagator_q,
            params=params,
            smooth=smooth,
            yield_criterion=self.yield_criterion,
        )

        # Update accumulated strain; thresholds and key are passed through
        new_strain = strain + shear_rate * dt
        return (new_stress, thresholds, new_strain, key)

    # ------------------------------------------------------------------
    # Public stress accessors
    # ------------------------------------------------------------------

    def get_shear_stress(self, stress: jax.Array) -> jax.Array:
        """Extract shear stress component from stress tensor.

        Args:
            stress: Stress tensor of shape (3, L, L) or (..., 3, L, L).

        Returns:
            Shear stress σ_xy, shape (L, L) or (..., L, L).
        """
        # Component ordering: [σ_xx, σ_yy, σ_xy]; the ellipsis form covers both
        # the unbatched (3, L, L) and batched (..., 3, L, L) layouts.
        return stress[..., 2, :, :]

    def get_normal_stress_differences(
        self, stress: jax.Array, nu: float | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Compute normal stress differences from stress tensor.

        For plane strain: σ_zz = ν(σ_xx + σ_yy)

        Normal stress differences:
            - N₁ = σ_xx - σ_yy
            - N₂ = σ_yy - σ_zz

        Args:
            stress: Stress tensor of shape (3, L, L) or (..., 3, L, L).
            nu: Poisson's ratio. If None, uses parameter value.

        Returns:
            Tuple (N₁, N₂), each of shape (L, L) or (..., L, L).
        """
        if nu is None:
            nu = self.parameters.get_value("nu")

        if stress.ndim >= 3:
            # Index the component axis (third from last) so batched inputs work
            # the same way as in get_shear_stress.
            sigma_xx = stress[..., 0, :, :]
            sigma_yy = stress[..., 1, :, :]
        else:
            # Lower-rank input: components are on the leading axis.
            sigma_xx = stress[0]
            sigma_yy = stress[1]

        # Plane strain constraint
        sigma_zz = nu * (sigma_xx + sigma_yy)

        N1 = sigma_xx - sigma_yy
        N2 = sigma_yy - sigma_zz
        return N1, N2

    def predict_normal_stresses(
        self, data: RheoData, **kwargs
    ) -> tuple[jax.Array, jax.Array]:
        """Convenience method to predict normal stress differences.

        Intended to run the simulation and extract N₁ and N₂ spatial averages
        over time. Currently a placeholder that always raises.

        Args:
            data: RheoData with protocol specification.
            **kwargs: Additional arguments passed to predict().

        Returns:
            Tuple (N₁_array, N₂_array) with time-averaged values.

        Raises:
            NotImplementedError: Not yet implemented (future feature).
        """
        raise NotImplementedError(
            "predict_normal_stresses() not yet implemented. "
            "Use predict() with test_mode='flow_curve', which returns σ_xy "
            "and stores N₁ in metadata['N1']."
        )

    # ------------------------------------------------------------------
    # Protocol runners
    # ------------------------------------------------------------------

    # Override flow_curve to return shear stress with N₁ in metadata
    def _run_flow_curve(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Steady state flow curve: Stress vs Shear Rate.

        For tensorial EPM, returns shear stress σ_xy with N₁ stored in metadata.

        Args:
            data: RheoData with x=shear_rates.
            key: PRNG key.
            propagator_q: Precomputed propagator.
            params: Model parameters.
            smooth: Use smooth yielding.

        Returns:
            RheoData with x=shear_rates, y=sigma_xy.
            metadata contains 'N1' with first normal stress differences.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        shear_rates = data.x
        if shear_rates is None:
            raise ValueError("data.x (shear rate array) must not be None")

        # Number of steps simulated per shear rate; the second half is averaged.
        n_steps = 1000
        nu = params["nu"]
        # The initial state does not depend on the shear rate, so build it once.
        initial_state = self._init_state(key)

        def steady_state(gdot):
            def body(carry, _):
                new_state = self._epm_step(
                    carry, propagator_q, gdot, self.dt, params, smooth
                )
                stress_tensor = new_state[0]  # Shape (3, L, L)

                # Lattice averages of shear stress and N₁
                sigma_xy_mean = jnp.mean(stress_tensor[2])
                n1_field, _ = self.get_normal_stress_differences(stress_tensor, nu)
                return new_state, jnp.array([sigma_xy_mean, jnp.mean(n1_field)])

            _, history = jax.lax.scan(body, initial_state, None, length=n_steps)

            # Average last 50% for steady state; history has shape (n_steps, 2)
            return jnp.mean(history[n_steps // 2 :], axis=0)  # [σ_xy, N₁]

        # Vectorize over shear rates -> shape (n_rates, 2)
        steady = jax.vmap(steady_state)(shear_rates)
        sigma_xy = steady[:, 0]
        N1 = steady[:, 1]

        # Store N₁ in metadata (copy so the input metadata is not mutated)
        result_metadata = data.metadata.copy() if data.metadata else {}
        result_metadata["N1"] = N1

        return RheoData(
            x=shear_rates,
            y=sigma_xy,
            initial_test_mode="flow_curve",
            metadata=result_metadata,
        )

    def _predict(self, X, **kwargs) -> RheoData:
        """Simulate the model for the given protocol.

        Args:
            X: Input data - can be RheoData or numpy/JAX array.
            kwargs:
                test_mode (str): 'flow_curve', 'startup', 'relaxation',
                    'creep', 'oscillation'.
                smooth (bool): Use smooth yielding (default False).
                seed (int): Random seed (default 0).

        Returns:
            RheoData with simulation results.
            - flow_curve: y is σ_xy array, metadata['N1'] contains N₁ values
            - Other protocols: y has shape (n_points,) with σ_xy only

        Raises:
            ValueError: If the resolved test_mode is not one of the supported
                protocols.
        """
        # Handle both RheoData and raw array input
        if isinstance(X, RheoData):
            rheo_data = X
            test_mode = kwargs.get("test_mode", rheo_data.test_mode)
        else:
            # Raw array input - wrap in RheoData
            test_mode = kwargs.get("test_mode")
            if test_mode is None:
                test_mode = getattr(self, "_test_mode", "flow_curve")

            x_array = jnp.asarray(X, dtype=jnp.float64)

            # Copy whichever protocol settings have been cached on the instance
            metadata = {}
            for meta_key in ("gamma_dot", "gamma", "stress", "gamma0", "omega"):
                cache_attr = f"_cached_{meta_key}"
                if hasattr(self, cache_attr):
                    metadata[meta_key] = getattr(self, cache_attr)

            rheo_data = RheoData(
                x=x_array,
                y=jnp.zeros_like(x_array),  # dummy y for RheoData constructor
                initial_test_mode=test_mode,
                metadata=metadata,
            )

        smooth = kwargs.get("smooth", False)
        seed = kwargs.get("seed", 0)
        key = jax.random.PRNGKey(seed)

        # Scale the normalized propagator by the current mu; the normalized
        # propagator itself is refreshed if nu changed since it was built.
        mu = self.parameters.get_value("mu")
        propagator_q = self._current_propagator_q_norm() * mu

        # Get full parameter dictionary
        param_dict = self._get_param_dict()

        runners = {
            "flow_curve": self._run_flow_curve,
            "startup": self._run_startup_tensorial,
            "relaxation": self._run_relaxation_tensorial,
            "creep": self._run_creep_tensorial,
            "oscillation": self._run_oscillation_tensorial,
        }
        runner = runners.get(test_mode)
        if runner is None:
            raise ValueError(f"Unknown test_mode: {test_mode}")
        return runner(rheo_data, key, propagator_q, param_dict, smooth)

    def _run_startup_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Start-up shear: Stress(t) at constant rate.

        Extracts shear component from tensorial stress. The shear rate is read
        from ``metadata['gamma_dot']`` (default 0.1).

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time, dt, n_steps = self._resolve_time_grid(data)

        # Constant shear rate from metadata
        metadata = data.metadata or {}
        gdot = metadata.get("gamma_dot", 0.1)

        state = self._init_state(key)

        def body(carry, _):
            new_state = self._epm_step(carry, propagator_q, gdot, dt, params, smooth)
            # Lattice-averaged shear stress component (index 2)
            return new_state, jnp.mean(new_state[0][2])

        stresses = self._scan_series(body, state, jnp.mean(state[0][2]), n_steps)
        return RheoData(x=time, y=stresses, initial_test_mode="startup")

    def _run_relaxation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Stress relaxation: G(t) after step strain.

        Extracts shear component from tensorial stress. The step strain is read
        from ``metadata['gamma']`` (default 0.1).

        Raises:
            ValueError: If ``data.x`` is None or the step strain is zero
                (G(t) = σ_xy / γ₀ is undefined in that case).
        """
        time, dt, n_steps = self._resolve_time_grid(data)

        # Step strain magnitude from metadata
        metadata = data.metadata or {}
        strain_step = metadata.get("gamma", 0.1)
        if bool(strain_step == 0):
            raise ValueError(
                "Step strain 'gamma' must be non-zero for relaxation modulus"
            )

        stress, thresh, strain, k = self._init_state(key)

        # Apply Step Strain (Elastic Load) - only to shear component
        mu = params["mu"]
        stress = stress.at[2].set(stress[2] + mu * strain_step)
        state = (stress, thresh, strain + strain_step, k)

        # Initial G(0)
        g_0 = jnp.mean(stress[2]) / strain_step

        def body(carry, _):
            # Relax with zero imposed shear rate
            new_state = self._epm_step(carry, propagator_q, 0.0, dt, params, smooth)
            # G(t) = Shear Stress / gamma_0
            return new_state, jnp.mean(new_state[0][2]) / strain_step

        moduli = self._scan_series(body, state, g_0, n_steps)
        return RheoData(x=time, y=moduli, initial_test_mode="relaxation")

    def _run_creep_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """Creep: Strain(t) at constant stress using Adaptive P-Controller.

        Extracts shear component from tensorial stress.

        The target stress is taken from ``metadata['stress']`` when present,
        otherwise from the mean of ``data.y``, otherwise 1.0. Metadata takes
        precedence because ``_predict`` wraps raw arrays with an all-zero
        placeholder ``y``, which would otherwise hide the stored stress.

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time, dt, n_steps = self._resolve_time_grid(data)

        metadata = data.metadata or {}
        if "stress" in metadata:
            target_stress = metadata["stress"]
        elif data.y is not None:
            target_stress = jnp.mean(data.y)
        else:
            target_stress = 1.0

        # Controller params: base proportional gain and error-dependent boost
        Kp_base = 0.01
        alpha = 10.0

        state = self._init_state(key)
        initial_strain = state[2]

        # Augmented carry: (EPM state, current shear rate)
        aug_state = (state, 0.0)

        def body(carry, _):
            curr_epm, gdot = carry

            # Lattice-averaged shear stress
            curr_stress = jnp.mean(curr_epm[0][2])

            # Adaptive control
            error = target_stress - curr_stress

            # Gain scheduling: boost gain if error is large relative to target
            rel_error = jnp.abs(error) / (jnp.abs(target_stress) + 1e-6)
            Kp = Kp_base * (1.0 + alpha * rel_error)

            # Update shear rate (P-control on rate), clamped to be non-negative
            gdot_new = jnp.maximum(gdot + Kp * error, 0.0)

            # Step EPM
            new_epm = self._epm_step(
                curr_epm, propagator_q, gdot_new, dt, params, smooth
            )

            # Emit accumulated strain
            return (new_epm, gdot_new), new_epm[2]

        strains = self._scan_series(body, aug_state, initial_strain, n_steps)
        return RheoData(x=time, y=strains, initial_test_mode="creep")

    def _run_oscillation_tensorial(
        self,
        data: RheoData,
        key: jax.Array,
        propagator_q: jax.Array,
        params: dict,
        smooth: bool,
    ) -> RheoData:
        """SAOS/LAOS: Stress(t) for sinusoidal strain.

        Extracts shear component from tensorial stress. Amplitude and angular
        frequency are read from ``metadata['gamma0']`` and ``metadata['omega']``
        (both default 1.0).

        Raises:
            ValueError: If ``data.x`` is None.
        """
        time, dt, n_steps = self._resolve_time_grid(data)

        metadata = data.metadata or {}
        gamma0 = metadata.get("gamma0", 1.0)
        omega = metadata.get("omega", 1.0)

        state = self._init_state(key)
        initial_stress = jnp.mean(state[0][2])

        def body(carry, t):
            # Time varying shear rate at current time t
            gdot = gamma0 * omega * jnp.cos(omega * t)
            new_state = self._epm_step(carry, propagator_q, gdot, dt, params, smooth)
            # Lattice-averaged shear stress component
            return new_state, jnp.mean(new_state[0][2])

        # Each step is driven by the rate at the start of its interval
        scan_time = time[:-1] if n_steps > 0 else None
        stresses = self._scan_series(
            body, state, initial_stress, n_steps, xs=scan_time
        )
        return RheoData(x=time, y=stresses, initial_test_mode="oscillation")

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------

    def _fit(self, X, y, **kwargs):
        """Fit model parameters to data with flexible target selection.

        Supports (target detection only; fitting itself is not implemented):
            - 1D y: Fit to shear stress σ_xy only (backward compatible)
            - 2D y with shape (2, n): Fit to [σ_xy, N₁] simultaneously
            - 2D y with shape (3, n): Not yet supported (full tensor fitting)

        Args:
            X: Shear rates or time array.
            y: Target data (1D or 2D array-like).
            **kwargs:
                test_mode (str): Protocol type (default 'flow_curve').
                Other fitting parameters.

        Raises:
            NotImplementedError: Always, for valid shapes - EPM fitting is
                complex and not yet fully implemented.
            ValueError: If ``y`` has an unsupported shape.
        """
        # Accept any array-like target for shape inspection
        y = jnp.asarray(y)

        # Auto-detect fitting mode from y shape
        if y.ndim == 1:
            # Shear-only fitting (backward compatible)
            fitting_mode = "shear_only"
        elif y.ndim == 2 and y.shape[0] == 2:
            # Combined fitting [σ_xy, N₁]
            # TODO: use the 'w_N1' parameter as the N₁ weight once fitting is
            # implemented.
            fitting_mode = "combined"
        elif y.ndim == 2 and y.shape[0] == 3:
            raise NotImplementedError(
                "Full tensor fitting (3 components) not yet supported. "
                "Use 1D y for shear-only or 2D y with shape (2, n) for [σ_xy, N₁]."
            )
        else:
            raise ValueError(
                f"Invalid y shape: {y.shape}. "
                "Expected 1D for shear-only or (2, n) for [σ_xy, N₁]."
            )

        # Fitting requires smooth approximation and gradient-based optimization
        # This is complex for EPM and requires careful implementation
        raise NotImplementedError(
            f"TensorialEPM fitting (mode: {fitting_mode}) not yet implemented. "
            "EPM parameter inference requires MCMC or specialized optimization. "
            "Use model.predict() for forward simulations."
        )