Source code for rheojax.models.sgr.sgr_generic

"""Soft Glassy Rheology (SGR) GENERIC Thermodynamic Framework Model.

This module implements the GENERIC (General Equation for Non-Equilibrium
Reversible-Irreversible Coupling) thermodynamic framework for the SGR model,
based on Fuereder & Ilg (2013) Physical Review E 88, 042134.

The GENERIC framework provides a thermodynamically consistent formulation by
splitting the dynamics into two parts:

1. Reversible (Hamiltonian) dynamics:
    dz/dt|_rev = L(z) * dF/dz

    where L is the antisymmetric Poisson bracket operator that generates
    reversible dynamics conserving energy.

2. Irreversible (dissipative) dynamics:
    dz/dt|_irrev = M(z) * dS/dz

    where M is the symmetric positive semi-definite friction matrix that
    generates entropy-producing irreversible dynamics.

The full GENERIC dynamics is:
    dz/dt = L(z) * dF/dz + M(z) * dS/dz

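Key thermodynamic constraints (standard GENERIC, with total energy E and
entropy S as generators):
- Entropy production: W = (dS/dz)^T M (dS/dz) >= 0 (second law), which
  holds because M is symmetric positive semi-definite
- Energy conservation in the reversible part: (dE/dz)^T L (dE/dz) = 0,
  which holds because L is antisymmetric
- Degeneracy conditions: L * dS/dz = 0 (the reversible part produces no
  entropy) and M * dE/dz = 0 (the irreversible part conserves energy)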

State Variables:
    For SGR, the GENERIC state vector z contains:
    - sigma: Stress (momentum-like variable conjugate to strain)
    - P(E,l): Trap occupation distribution (structural variable)

    In the simplified formulation used here:
    - z[0] = sigma: Macroscopic stress
    - z[1] = lambda: Structural parameter (0 = broken, 1 = intact)

Physical Interpretation:
    The GENERIC framework ensures that the SGR model satisfies fundamental
    thermodynamic laws: energy conservation (first law) and entropy production
    (second law). The Poisson bracket encodes the reversible coupling between
    stress and strain rate (Hamiltonian mechanics), while the friction matrix
    encodes the irreversible trap hopping dynamics that produces entropy.

References:
    - I. Fuereder and P. Ilg, GENERIC framework for the Fokker-Planck equation,
      Physical Review E, 2013, 88, 042134
    - P. Sollich, Rheological constitutive equation for a model of soft glassy
      materials, Physical Review E, 1998, 58(1), 738-759
    - H.C. Ottinger, Beyond Equilibrium Thermodynamics, Wiley, 2005
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.sgr_kernels import G0, Gp

# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()

if TYPE_CHECKING:  # pragma: no cover
    import jax.numpy as jnp_typing
else:
    jnp_typing = np

# Module logger
logger = get_logger(__name__)


[docs] @ModelRegistry.register( "sgr_generic", protocols=[ Protocol.FLOW_CURVE, Protocol.CREEP, Protocol.RELAXATION, Protocol.STARTUP, Protocol.OSCILLATION, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class SGRGeneric(BaseModel): """Soft Glassy Rheology (SGR) GENERIC Thermodynamic Framework Model. This model implements the GENERIC (General Equation for Non-Equilibrium Reversible-Irreversible Coupling) thermodynamic framework for SGR, ensuring thermodynamic consistency via explicit entropy production tracking. The GENERIC formulation splits dynamics into: - Reversible (Hamiltonian): dz/dt = L * dF/dz (Poisson bracket L antisymmetric) - Irreversible (dissipative): dz/dt = M * dS/dz (friction M symmetric PSD) Parameters: x: Effective noise temperature (dimensionless), controls phase transition G0: Modulus scale (Pa), sets absolute magnitude of elastic response tau0: Attempt time (s), characteristic microscopic relaxation timescale State Variables: z = [sigma, lambda] where: - sigma: Macroscopic stress (Pa) - lambda: Structural parameter [0, 1] representing trap occupation Thermodynamic Functions: - F(z): Helmholtz free energy = U(z) - T*S(z) - U(z): Internal energy from elastic storage - S(z): Entropy from trap distribution - W: Entropy production rate = (dF/dz)^T M (dF/dz) >= 0 Example: >>> from rheojax.models.sgr_generic import SGRGeneric >>> import numpy as np >>> model = SGRGeneric() >>> omega = np.logspace(-2, 2, 50) >>> model._test_mode = 'oscillation' >>> G_star = model.predict(omega) >>> # Check thermodynamic consistency >>> state = np.array([100.0, 0.5]) >>> W = model.compute_entropy_production(state) >>> assert W >= 0, "Second law violated!" 
Notes: - Inherits from BaseModel (includes BayesianMixin for NumPyro NUTS) - Predictions match SGRConventional in linear viscoelastic regime - GENERIC structure guarantees thermodynamic consistency - Reference: Fuereder & Ilg 2013 PRE 88, 042134 """
def __init__(self, dynamic_x: bool = False):
    """Initialize the SGR GENERIC model.

    Registers three fit parameters on a fresh ``ParameterSet`` (the same
    set SGRConventional uses, for compatibility):

    - ``x``    effective noise temperature, bounds (0.5, 3.0), default 1.5
    - ``G0``   modulus scale in Pa, bounds (1e-3, 1e9), default 1e3
    - ``tau0`` attempt time in s, bounds (1e-9, 1e3), default 1e-3

    Args:
        dynamic_x: When True, the model evolves the noise temperature
            dynamically using the 3D state [sigma, lambda, x], and the
            extra parameters are set up via ``_init_dynamic_x_parameters``.
            Defaults to False for backward compatibility.
    """
    super().__init__()

    # (name, default value, bounds, units, description), registered in order.
    parameter_specs = (
        (
            "x",
            1.5,
            (0.5, 3.0),
            "dimensionless",
            "Effective noise temperature (glass transition at x=1)",
        ),
        (
            "G0",
            1e3,
            (1e-3, 1e9),
            "Pa",
            "Modulus scale (absolute magnitude of elastic response)",
        ),
        (
            "tau0",
            1e-3,
            (1e-9, 1e3),
            "s",
            "Attempt time (microscopic relaxation timescale)",
        ),
    )
    self.parameters = ParameterSet()
    for name, value, bounds, units, description in parameter_specs:
        self.parameters.add(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
        )

    # Test mode recorded at fit time, used for mode-aware Bayesian inference.
    self._test_mode: TestMode | str | None = None

    # Running total used for entropy-production tracking.
    self._cumulative_entropy_production: float = 0.0

    # Feature flags for the extended (thixotropic / dynamic-x) behaviour.
    self._thixotropy_enabled: bool = False
    self._dynamic_x: bool = dynamic_x

    # LAOS strain amplitude and frequency; unset until LAOS data is fitted.
    self._gamma_0: float | None = None
    self._omega_laos: float | None = None

    # Structural-parameter (lambda) trajectory kept for thixotropy.
    self._lambda_trajectory: np.ndarray | None = None

    if dynamic_x:
        self._init_dynamic_x_parameters()
# ========================================================================= # GENERIC State Variables and Thermodynamic Functions # =========================================================================
def free_energy(self, state: np.ndarray) -> float:
    """Return the Helmholtz free energy F(z) = U(z) - T*S(z).

    The elastic part U comes from ``internal_energy`` and the configurational
    part S comes from ``entropy``. The noise temperature parameter ``x``
    plays the role of T (in units of the trap depth).

    Args:
        state: State vector [sigma, lambda]: stress in Pa and the structural
            parameter in [0, 1].

    Returns:
        Free energy F (J/m^3, equivalently Pa, depending on normalization).
    """
    elastic = self.internal_energy(state)
    configurational = self.entropy(state)
    temperature = self.parameters.get_value("x")  # noise temperature acts as T
    assert temperature is not None
    return elastic - temperature * configurational
def internal_energy(self, state: np.ndarray) -> float:
    """Compute the internal energy U(z) stored in elastically deformed traps.

    The effective modulus depends linearly on the structural parameter:

        G_eff = G0 * G0(x) * lambda

    Here G0 is the modulus-scale parameter and G0(x) is the dimensionless
    SGR equilibrium modulus from ``rheojax.utils.sgr_kernels``. The stored
    elastic energy is then

        U = sigma^2 / (2 * G_eff)

    A 1e-20 guard is added to the denominator.

    Args:
        state: State vector [sigma, lambda]. lambda is clipped to
            [0.01, 1.0] so that G_eff stays away from zero.

    Returns:
        Internal energy U (J/m^3, i.e. Pa).
    """
    sigma = state[0]
    lam = np.clip(state[1], 0.01, 1.0)  # Prevent division by zero
    G0_val = self.parameters.get_value("G0")
    x = self.parameters.get_value("x")
    assert G0_val is not None
    assert x is not None

    # Dimensionless equilibrium modulus G0(x).
    G0_dim = G0(x)  # R10-SGR-005: removed float() to preserve JAX traceability

    # Effective modulus scales linearly with structure (lambda^1).
    G_eff = G0_val * G0_dim * lam

    # Elastic energy: U = sigma^2 / (2 * G_eff)
    U = sigma**2 / (2.0 * G_eff + 1e-20)
    return U
def entropy(self, state: np.ndarray) -> float:
    """Return the binary mixing entropy of the structural parameter.

        S = -[lambda * ln(lambda) + (1 - lambda) * ln(1 - lambda)]

    This is the configurational entropy of splitting elements between
    trapped (structured, fraction lambda) and free (unstructured,
    fraction 1 - lambda) states. The value is dimensionless (normalized
    by k_B). The stress component of ``state`` is ignored.

    Args:
        state: State vector [sigma, lambda].

    Returns:
        Dimensionless entropy S >= 0.
    """
    # Keep lambda strictly inside (0, 1) so both logarithms stay finite.
    trapped = np.clip(state[1], 1e-10, 1.0 - 1e-10)
    free = 1.0 - trapped
    return -(trapped * np.log(trapped) + free * np.log(free))
# ========================================================================= # GENERIC Operators: Poisson Bracket L and Friction Matrix M # =========================================================================
def poisson_bracket(self, state: np.ndarray) -> np.ndarray:
    """Return the 2x2 antisymmetric Poisson bracket operator L(z).

    L generates the reversible (Hamiltonian) part of the dynamics:

        L = [[0,     L_12],
             [-L_12, 0   ]],   L_12 = G0 * G0(x) * lambda / tau0

    L_12 is a Maxwell-like stress/strain-rate coupling whose strength is
    set by the structure-dependent modulus. Because L = -L^T, applying L
    to dF/dz leaves F unchanged: (dF/dz)^T L (dF/dz) = 0.

    Args:
        state: State vector [sigma, lambda]. lambda is clipped to
            [0.01, 1.0].

    Returns:
        2x2 numpy array holding the antisymmetric operator L.
    """
    structure = np.clip(state[1], 0.01, 1.0)
    lookup = self.parameters.get_value
    modulus_scale = lookup("G0")
    attempt_time = lookup("tau0")
    noise_temp = lookup("x")
    assert modulus_scale is not None
    assert attempt_time is not None
    assert noise_temp is not None

    # R10-SGR-005: G0(x) is used without float() to preserve JAX traceability.
    coupling = modulus_scale * G0(noise_temp) * structure / attempt_time

    return np.array([[0.0, coupling], [-coupling, 0.0]])
def friction_matrix(self, state: np.ndarray) -> np.ndarray:
    """Return the 2x2 symmetric friction matrix M(z).

    M generates the irreversible (dissipative) dynamics. It must be
    symmetric positive semi-definite so that entropy production is
    non-negative.

    With rate = exp(-1/x) / tau0 and lambda clipped to [0.01, 1.0]:

        M_11 = rate * G0 * G0(x) * lambda     (stress / viscous dissipation)
        M_22 = rate * lambda * (1 - lambda)   (structural / trap dissipation)
        M_12 = alpha * sqrt(M_11 * M_22 + 1e-20), alpha = 0

    The cross term therefore vanishes and the two channels are decoupled.
    The geometric-mean form keeps M PSD for |alpha| <= 1 (up to the 1e-20
    guard). The Arrhenius-like factor exp(-1/x) makes a higher noise
    temperature relax faster (more trap hopping).

    Args:
        state: State vector [sigma, lambda]. Only lambda is used.

    Returns:
        2x2 numpy array holding the symmetric friction matrix M.
    """
    structure = np.clip(state[1], 0.01, 1.0)
    # sigma (state[0]) is intentionally unused: dissipation is structure-based.
    lookup = self.parameters.get_value
    modulus_scale = lookup("G0")
    attempt_time = lookup("tau0")
    noise_temp = lookup("x")
    assert modulus_scale is not None
    assert attempt_time is not None
    assert noise_temp is not None

    # R10-SGR-005: G0(x) is used without float() to preserve JAX traceability.
    effective_modulus = modulus_scale * G0(noise_temp) * structure
    base_rate = 1.0 / attempt_time
    rate = np.exp(-1.0 / noise_temp) * base_rate

    stress_term = rate * effective_modulus
    structure_term = rate * structure * (1.0 - structure)

    # Zero coupling decouples stress and structure.
    coupling_strength = 0.0
    cross_term = coupling_strength * np.sqrt(stress_term * structure_term + 1e-20)

    return np.array([[stress_term, cross_term], [cross_term, structure_term]])
# ========================================================================= # GENERIC Dynamics # =========================================================================
def reversible_dynamics(self, state: np.ndarray) -> np.ndarray:
    """Return the reversible (Hamiltonian) contribution to dz/dt.

        dz/dt|_rev = L(z) @ dF/dz

    L comes from ``poisson_bracket`` and dF/dz from
    ``free_energy_gradient``. This is the part of the dynamics that
    couples the state variables without dissipation.

    Args:
        state: State vector [sigma, lambda].

    Returns:
        Array with the reversible time derivative of each state variable.
    """
    return self.poisson_bracket(state) @ self.free_energy_gradient(state)
def irreversible_dynamics(self, state: np.ndarray) -> np.ndarray:
    """Compute the irreversible (dissipative) contribution to dz/dt.

    The value actually returned is the free-energy (isothermal) form

        dz/dt|_irrev = -M(z) @ dF/dz

    M comes from ``friction_matrix`` and dF/dz from
    ``free_energy_gradient``. No explicit 1/T factor is applied, so this is
    not the textbook ``M @ dS/dz`` expression. The minus sign makes the
    dissipative part descend the free-energy landscape: for a symmetric
    positive semi-definite M,
    dF/dt|_irrev = -(dF/dz)^T M (dF/dz) <= 0.

    Args:
        state: State vector [sigma, lambda].

    Returns:
        Array with the irreversible time derivative of each state variable.
    """
    M = self.friction_matrix(state)
    dF_dz = self.free_energy_gradient(state)
    # Negative free-energy gradient drives the system toward equilibrium.
    return -M @ dF_dz
def full_dynamics(self, state: np.ndarray) -> np.ndarray:
    """Compute the full GENERIC time derivative dz/dt.

    The result is the sum of the reversible part (``reversible_dynamics``,
    L(z) @ dF/dz) and the irreversible part (``irreversible_dynamics``,
    which returns -M(z) @ dF/dz). In total:

        dz/dt = L(z) @ dF/dz - M(z) @ dF/dz

    Args:
        state: State vector [sigma, lambda].

    Returns:
        Array with the total time derivative of each state variable.
    """
    dz_dt_rev = self.reversible_dynamics(state)
    dz_dt_irrev = self.irreversible_dynamics(state)
    return dz_dt_rev + dz_dt_irrev
# ========================================================================= # Thermodynamic Consistency Checks # =========================================================================
def entropy_production_rate(self, state: np.ndarray) -> float:
    """Return the entropy production rate dS/dt = W / T.

    W is the dissipation returned by ``compute_entropy_production`` and T
    is the noise temperature ``x``. A 1e-20 guard is added to T before
    dividing.

    Args:
        state: State vector [sigma, lambda].

    Returns:
        Entropy production rate dS/dt (non-negative whenever W is).
    """
    noise_temperature = self.parameters.get_value("x")
    # Dissipation at temperature T produces entropy at rate W / T.
    dissipation = self.compute_entropy_production(state)
    assert noise_temperature is not None
    return dissipation / (noise_temperature + 1e-20)
# ========================================================================= # BaseModel Interface Implementation # ========================================================================= def _fit( self, X: np.ndarray, y: np.ndarray, **kwargs, ) -> SGRGeneric: """Fit SGR GENERIC model to data using NLSQ optimization. Routes to appropriate fitting method based on test_mode. Args: X: Independent variable (frequency for oscillation, time for relaxation) y: Dependent variable (complex modulus, relaxation modulus, etc.) **kwargs: NLSQ optimizer arguments. Must include test_mode ('oscillation', 'relaxation', 'creep', 'steady_shear', 'laos'). Raises: ValueError: If test_mode not provided or invalid """ test_mode = kwargs.pop("test_mode", None) if test_mode is None: raise ValueError("test_mode must be specified for SGR GENERIC fitting") with log_fit(logger, model="SGRGeneric", data_shape=X.shape) as ctx: try: logger.info( "Starting SGR GENERIC model fit", test_mode=test_mode, n_points=len(X), ) logger.debug( "Input data statistics", x_range=(float(np.min(X)), float(np.max(X))), y_range=(float(np.min(np.abs(y))), float(np.max(np.abs(y)))), ) # Store test mode for mode-aware Bayesian inference self._test_mode = test_mode ctx["test_mode"] = test_mode # Route to appropriate fitting method if test_mode == "oscillation": self._fit_oscillation_mode(X, y, **kwargs) elif test_mode == "relaxation": self._fit_relaxation_mode(X, y, **kwargs) elif test_mode == "creep": self._fit_creep_mode(X, y, **kwargs) elif test_mode in ("steady_shear", "flow_curve"): self._fit_steady_shear_mode(X, y, **kwargs) elif test_mode == "laos": self._fit_laos_mode(X, y, **kwargs) elif test_mode == "startup": self._fit_startup_mode(X, y, **kwargs) else: raise ValueError( f"Unsupported test_mode: {test_mode}. " f"SGR GENERIC model supports 'oscillation', 'relaxation', " f"'creep', 'steady_shear', 'laos', 'startup'." 
) # Log final parameters x_val = self.parameters.get_value("x") G0_val = self.parameters.get_value("G0") tau0_val = self.parameters.get_value("tau0") ctx["x"] = x_val ctx["G0"] = G0_val ctx["tau0"] = tau0_val ctx["phase_regime"] = self.get_phase_regime() logger.info( "SGR GENERIC model fit completed", x=x_val, G0=G0_val, tau0=tau0_val, phase_regime=self.get_phase_regime(), ) except Exception as e: logger.error( "SGR GENERIC model fit failed", test_mode=test_mode, error=str(e), exc_info=True, ) raise return self def _fit_oscillation_mode( self, omega: np.ndarray, G_star: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to complex modulus data (oscillation mode). Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0] to complex modulus data G*(omega). The GENERIC model uses the same kernel functions as SGRConventional in the linear viscoelastic regime. Args: omega: Angular frequency array (rad/s) G_star: Complex modulus data. Accepted formats: - Complex array (M,) where G* = G' + i*G'' - Real array (M, 2) where columns are [G', G''] **kwargs: NLSQ optimizer arguments Raises: RuntimeError: If optimization fails to converge """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Convert inputs to JAX arrays omega_jax = jnp.asarray(omega, dtype=jnp.float64) # Handle G_star format G_star_np = np.asarray(G_star) if np.iscomplexobj(G_star_np): G_star_2d = np.column_stack([np.real(G_star_np), np.imag(G_star_np)]) elif G_star_np.ndim == 2 and G_star_np.shape[1] == 2: G_star_2d = G_star_np elif G_star_np.ndim == 2 and G_star_np.shape[0] == 2: G_star_2d = G_star_np.T else: raise ValueError( f"G_star must be complex (M,) or real (M, 2), got shape {G_star_np.shape}" ) G_star_jax = jnp.asarray(G_star_2d, dtype=jnp.float64) # Create model function for NLSQ def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: x_param = params[0] G0_param = params[1] tau0_param = params[2] return 
self._predict_oscillation_jit(x_data, x_param, G0_param, tau0_param) # Create residual function objective = create_least_squares_objective( model_fn, omega_jax, G_star_jax, normalize=True, use_log_residuals=kwargs.get("use_log_residuals", False), ) # Run NLSQ optimization result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), max_iter=kwargs.get("max_iter", 1000), ftol=kwargs.get("ftol", 1e-6), xtol=kwargs.get("xtol", 1e-6), gtol=kwargs.get("gtol", 1e-6), ) if not result.success: raise RuntimeError( f"SGR GENERIC oscillation fitting failed: {result.message}. " "Try adjusting initial values or bounds." ) logger.debug( f"SGR GENERIC oscillation fit converged: x={self.parameters.get_value('x'):.4f}, " f"G0={self.parameters.get_value('G0'):.2e}, " f"tau0={self.parameters.get_value('tau0'):.2e}, " f"cost={result.fun:.3e}" ) self.fitted_ = True def _fit_relaxation_mode( self, t: np.ndarray, G_t: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to relaxation modulus data (relaxation mode). Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0] to relaxation modulus data G(t). 
Args: t: Time array (s) G_t: Relaxation modulus array (Pa) **kwargs: NLSQ optimizer arguments Raises: RuntimeError: If optimization fails to converge """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Convert inputs to JAX arrays t_jax = jnp.asarray(t, dtype=jnp.float64) G_t_jax = jnp.asarray(G_t, dtype=jnp.float64) # Create model function for NLSQ def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: x_param = params[0] G0_param = params[1] tau0_param = params[2] return self._predict_relaxation_jit(x_data, x_param, G0_param, tau0_param) # Create residual function (log-space for power-law data) objective = create_least_squares_objective( model_fn, t_jax, G_t_jax, normalize=True, use_log_residuals=kwargs.get("use_log_residuals", True), ) # Run NLSQ optimization result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), max_iter=kwargs.get("max_iter", 1000), ftol=kwargs.get("ftol", 1e-6), xtol=kwargs.get("xtol", 1e-6), gtol=kwargs.get("gtol", 1e-6), ) if not result.success: raise RuntimeError( f"SGR GENERIC relaxation fitting failed: {result.message}. " "Try adjusting initial values or bounds." ) logger.debug( f"SGR GENERIC relaxation fit converged: x={self.parameters.get_value('x'):.4f}, " f"G0={self.parameters.get_value('G0'):.2e}, " f"tau0={self.parameters.get_value('tau0'):.2e}, " f"cost={result.fun:.3e}" ) self.fitted_ = True def _fit_creep_mode( self, t: np.ndarray, J_t: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to creep compliance data (creep mode). Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0] to creep compliance data J(t). 
Theory: For x > 1 (fluid), J(t) ~ t^(x-1) Args: t: Time array (s) J_t: Creep compliance array (1/Pa) **kwargs: NLSQ optimizer arguments Raises: RuntimeError: If optimization fails to converge """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Convert inputs to JAX arrays t_jax = jnp.asarray(t, dtype=jnp.float64) J_t_jax = jnp.asarray(J_t, dtype=jnp.float64) # Create model function for NLSQ def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: x_param = params[0] G0_param = params[1] tau0_param = params[2] return self._predict_creep_jit(x_data, x_param, G0_param, tau0_param) # Create residual function (log-space for compliance spanning decades) objective = create_least_squares_objective( model_fn, t_jax, J_t_jax, normalize=True, use_log_residuals=kwargs.get("use_log_residuals", True), ) # Run NLSQ optimization result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), max_iter=kwargs.get("max_iter", 1000), ftol=kwargs.get("ftol", 1e-6), xtol=kwargs.get("xtol", 1e-6), gtol=kwargs.get("gtol", 1e-6), ) if not result.success: raise RuntimeError( f"SGR GENERIC creep fitting failed: {result.message}. " "Try adjusting initial values or bounds." ) logger.debug( f"SGR GENERIC creep fit converged: x={self.parameters.get_value('x'):.4f}, " f"G0={self.parameters.get_value('G0'):.2e}, " f"tau0={self.parameters.get_value('tau0'):.2e}, " f"cost={result.fun:.3e}" ) self.fitted_ = True def _fit_steady_shear_mode( self, gamma_dot: np.ndarray, sigma: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to steady shear flow curve data. Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0] to flow curve data sigma(gamma_dot). 
Theory: - Fluid (x > 1): sigma ~ gamma_dot^(x-1) - Glass (x < 1): sigma = sigma_y + A*gamma_dot^(1-x) Args: gamma_dot: Shear rate array (1/s) sigma: Stress array (Pa) **kwargs: NLSQ optimizer arguments Raises: RuntimeError: If optimization fails to converge """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Convert inputs to JAX arrays gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64) sigma_jax = jnp.asarray(sigma, dtype=jnp.float64) # Create model function for NLSQ def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: x_param = params[0] G0_param = params[1] tau0_param = params[2] return self._predict_steady_shear_jit(x_data, x_param, G0_param, tau0_param) # Create residual function (log-space for power-law data) objective = create_least_squares_objective( model_fn, gamma_dot_jax, sigma_jax, normalize=True, use_log_residuals=kwargs.get("use_log_residuals", True), ) # Run NLSQ optimization result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), max_iter=kwargs.get("max_iter", 1000), ftol=kwargs.get("ftol", 1e-6), xtol=kwargs.get("xtol", 1e-6), gtol=kwargs.get("gtol", 1e-6), ) if not result.success: raise RuntimeError( f"SGR GENERIC steady shear fitting failed: {result.message}. " "Try adjusting initial values or bounds." ) logger.debug( f"SGR GENERIC steady shear fit converged: x={self.parameters.get_value('x'):.4f}, " f"G0={self.parameters.get_value('G0'):.2e}, " f"tau0={self.parameters.get_value('tau0'):.2e}, " f"cost={result.fun:.3e}" ) self.fitted_ = True def _fit_startup_mode( self, t: np.ndarray, eta_plus: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to startup flow data (stress growth coefficient). Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0] to stress growth coefficient η⁺(t) data. 
Args: t: Time array (s) eta_plus: Stress growth coefficient array (Pa·s) **kwargs: NLSQ optimizer arguments, plus: - gamma_dot: Applied shear rate (required if y is stress) - is_stress: If True, treat y as stress """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) gamma_dot = kwargs.get("gamma_dot", 1.0) is_stress = kwargs.get("is_stress", False) # Store gamma_dot for prediction self._startup_gamma_dot = gamma_dot if is_stress: eta_plus_data = eta_plus / gamma_dot else: eta_plus_data = eta_plus t_jax = jnp.asarray(t, dtype=jnp.float64) eta_plus_jax = jnp.asarray(eta_plus_data, dtype=jnp.float64) def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: x_param = params[0] G0_param = params[1] tau0_param = params[2] return self._predict_startup_jit( x_data, x_param, G0_param, tau0_param, gamma_dot ) objective = create_least_squares_objective( model_fn, t_jax, eta_plus_jax, normalize=True, use_log_residuals=kwargs.get("use_log_residuals", True), ) # Filter protocol kwargs before forwarding to NLSQ optimizer _SGR_RESERVED = { "gamma_dot", "is_stress", "use_log_residuals", "test_mode", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _SGR_RESERVED} result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs) if not result.success: raise RuntimeError(f"SGR GENERIC startup fitting failed: {result.message}") logger.debug( f"SGR GENERIC startup fit converged: x={self.parameters.get_value('x'):.4f}, " f"G0={self.parameters.get_value('G0'):.2e}, " f"tau0={self.parameters.get_value('tau0'):.2e}, " f"cost={result.fun:.3e}" ) self.fitted_ = True def _fit_laos_mode( self, t: np.ndarray, sigma: np.ndarray, **kwargs, ) -> None: """Fit SGR GENERIC to LAOS stress data. Uses Monte Carlo or Population Balance solver for time-domain stress prediction, then optimizes parameters to match measured stress. 
Args: t: Time array (s) sigma: Stress array (Pa) **kwargs: Required kwargs: - gamma_0: Strain amplitude - omega: Angular frequency (rad/s) Optional kwargs: - n_particles: Monte Carlo particle count (default 5000) - use_pde: Use PDE solver instead of MC (default False) Raises: ValueError: If gamma_0 or omega not provided RuntimeError: If optimization fails """ gamma_0 = kwargs.get("gamma_0") omega = kwargs.get("omega") if gamma_0 is None or omega is None: raise ValueError("LAOS fitting requires gamma_0 and omega in kwargs") if gamma_0 <= 0: raise ValueError(f"gamma_0 must be positive, got {gamma_0}") n_particles = kwargs.get("n_particles", 5000) use_pde = kwargs.get("use_pde", False) logger.info( f"SGR GENERIC LAOS fitting: gamma_0={gamma_0}, omega={omega}, " f"{'PDE' if use_pde else 'MC'} solver with {n_particles if not use_pde else 'grid'}" ) # Store LAOS parameters self._gamma_0 = gamma_0 self._omega_laos = omega # For now, use analytical approximation for small amplitude # Full MC/PDE fitting would require iterative simulation if gamma_0 < 0.1: # Small amplitude - use SAOS approximation logger.warning( f"Small strain amplitude gamma_0={gamma_0}. Using SAOS approximation." 
) # Use JAX-native FFT for JAX-First compliance sigma_fft = jnp.fft.fft(jnp.asarray(sigma)) n = len(sigma) fundamental_idx = int(omega * (t[-1] - t[0]) / (2 * np.pi)) fundamental_idx = max(1, min(fundamental_idx, n // 2 - 1)) G_star_amplitude = ( 2.0 * float(jnp.abs(sigma_fft[fundamental_idx])) / (n * gamma_0) ) phase = float(jnp.angle(sigma_fft[fundamental_idx])) G_prime = G_star_amplitude * np.cos(phase) G_double_prime = G_star_amplitude * np.sin(phase) # Fit to single-point SAOS omega_single = np.array([omega]) G_star_single = np.array([[G_prime, G_double_prime]]) self._fit_oscillation_mode(omega_single, G_star_single, **kwargs) else: # Large amplitude - full MC-based LAOS fitting self._fit_laos_mc(t, sigma, gamma_0, omega, n_particles, **kwargs) def _fit_laos_mc( self, t: np.ndarray, sigma: np.ndarray, gamma_0: float, omega: float, n_particles: int, **kwargs, ) -> None: """Full Monte Carlo-based LAOS fitting. Runs MC simulations within optimization loop to match time-domain stress. Args: t: Time array (s) sigma: Measured stress array (Pa) gamma_0: Strain amplitude omega: Angular frequency (rad/s) n_particles: Number of MC particles **kwargs: Optimizer arguments Note: Uses scipy.optimize.minimize (L-BFGS-B) because the objective function calls Monte Carlo simulations which are stochastic and not JAX-traceable. This is acceptable per Technical Guidelines as it's used only for large- amplitude LAOS fitting, not the primary oscillation/relaxation modes. 
""" from scipy.optimize import minimize from rheojax.utils.sgr_monte_carlo import simulate_oscillatory logger.info( f"Full MC-based LAOS fitting: {n_particles} particles, " f"gamma_0={gamma_0}, omega={omega:.3f} rad/s" ) # Determine simulation parameters from data period = 2.0 * np.pi / omega t_total = t[-1] - t[0] n_cycles = max(1, int(t_total / period)) points_per_cycle = max(10, len(t) // n_cycles) # Warm-start: estimate parameters from stress amplitude sigma_max = np.max(np.abs(sigma)) G0_init = sigma_max / gamma_0 x_init = self.parameters.get_value("x") tau0_init = self.parameters.get_value("tau0") assert x_init is not None assert tau0_init is not None # Normalize target stress for residual calculation sigma_norm = sigma / (sigma_max + 1e-12) # Fixed random seed for reproducibility within optimization seed = kwargs.get("seed", 42) def objective(params): """Compute residual between MC stress and measured stress.""" x_val, log_G0, log_tau0 = params G0_val = np.exp(log_G0) tau0_val = np.exp(log_tau0) # Clamp x to valid range x_val = np.clip(x_val, 0.5, 2.5) try: # Run MC simulation key = jax.random.PRNGKey(seed) _, _, sigma_mc = simulate_oscillatory( key=key, gamma_0=gamma_0, omega=omega, n_cycles=n_cycles, points_per_cycle=points_per_cycle, x=x_val, n_particles=n_particles, k=G0_val, Gamma0=1.0 / tau0_val, xg=1.0, ) # Interpolate to match data time points t_mc = np.linspace(0, t_total, len(sigma_mc)) sigma_mc_interp = np.interp(t - t[0], t_mc, np.array(sigma_mc)) # Normalize MC stress sigma_mc_max = np.max(np.abs(sigma_mc_interp)) + 1e-12 sigma_mc_norm = sigma_mc_interp / sigma_mc_max # Compute residual (allow phase shift by minimizing over shifts) residual = np.sum((sigma_mc_norm - sigma_norm) ** 2) return residual except Exception as e: logger.warning(f"MC simulation failed: {e}") return 1e10 # Large penalty # Initial guess in log space for G0, tau0 x0 = np.array([x_init, np.log(G0_init), np.log(tau0_init)]) # Bounds bounds = [ (0.5, 2.5), # x (np.log(1e-3), 
np.log(1e9)), # log(G0) (np.log(1e-9), np.log(1e3)), # log(tau0) ] # Run optimization max_iter = kwargs.get("max_iter", 50) logger.info(f"Starting MC-LAOS optimization (max {max_iter} iterations)...") result = minimize( objective, x0, method="L-BFGS-B", bounds=bounds, options={"maxiter": max_iter, "disp": False}, ) # Update parameters x_opt, log_G0_opt, log_tau0_opt = result.x self.parameters.set_value("x", float(x_opt)) self.parameters.set_value("G0", float(np.exp(log_G0_opt))) self.parameters.set_value("tau0", float(np.exp(log_tau0_opt))) if result.success: logger.info( f"MC-LAOS fit converged: x={x_opt:.4f}, " f"G0={np.exp(log_G0_opt):.2e}, tau0={np.exp(log_tau0_opt):.2e}, " f"cost={result.fun:.3e}" ) else: logger.warning( f"MC-LAOS fit did not fully converge: {result.message}. " f"Best: x={x_opt:.4f}, G0={np.exp(log_G0_opt):.2e}" ) self.fitted_ = True @staticmethod @jax.jit def _predict_creep_jit( t: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled creep prediction: J(t). Theory: J(t) ~ t^(x-1) for x > 1 (fluid regime) Args: t: Time array (s) x: Effective noise temperature (dimensionless) G0_scale: Modulus scale (Pa) tau0: Attempt time (s) Returns: Creep compliance J(t) with shape (M,) """ # Dimensionless time t_scaled = t / tau0 # Compute equilibrium modulus factor G0_dim = G0(x) epsilon = 1e-12 t_safe = jnp.maximum(t_scaled, epsilon) # Creep compliance: J(t) ~ (1 + t/tau0)^(x-1) / G0 # This is the inverse relationship to G(t) growth_exp = x - 1.0 J_t = jnp.power(1.0 + t_safe, growth_exp) / (G0_scale * G0_dim) # Monotonicity enforced by physical parameter bounds, not in NUTS path return J_t @staticmethod @jax.jit def _predict_steady_shear_jit( gamma_dot: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled steady shear prediction: sigma(gamma_dot). 
Theory: - Fluid (x > 1): sigma ~ gamma_dot^(x-1) - Glass (x < 1): sigma = sigma_y + A*gamma_dot^(1-x) Args: gamma_dot: Shear rate array (1/s) x: Effective noise temperature (dimensionless) G0_scale: Modulus scale (Pa) tau0: Attempt time (s) Returns: Stress sigma(gamma_dot) with shape (M,) """ # Compute equilibrium modulus factor G0_dim = G0(x) epsilon = 1e-12 gamma_dot_safe = jnp.maximum(gamma_dot, epsilon) # Dimensionless shear rate gamma_dot_scaled = gamma_dot_safe * tau0 # Flow curve: sigma = G0 * tau0 * gamma_dot * (gamma_dot * tau0)^(x-2) # = G0 * (gamma_dot * tau0)^(x-1) sigma = G0_scale * G0_dim * jnp.power(gamma_dot_scaled, x - 1.0) return sigma @staticmethod @jax.jit def _predict_startup_jit( t: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, gamma_dot: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled startup flow prediction: eta_plus(t). Computes stress growth coefficient η⁺(t) = σ(t)/γ̇ = ∫₀ᵗ G(s) ds. Same analytical form as SGRConventional for linear viscoelastic envelope. """ from rheojax.utils.sgr_kernels import G0 as G0_func # Dimensionless time t_scaled = t / tau0 # Compute equilibrium modulus factor G0_dim = G0_func(x) epsilon = 1e-12 t_safe = jnp.maximum(t_scaled, epsilon) # eta_plus = INT_0^t G ds with G ~ (1+s/tau0)^(1-x) gives exponent (2-x): # 1<x<2 grows (no finite zero-shear viscosity), x>2 saturates to # eta_0 = G0*G0(x)*tau0/(x-2) analytically. (Earlier versions used x-1 # plus a steady-state clamp; that clamp encoded the old, incorrect # exponent and would now wrongly cap the growing 1<x<2 regime, so it is # removed — kept in lockstep with SGRConventional.) 
exp = 2.0 - x def exp_near_zero(_): # INT (1+s/tau0)^(-1) ds = tau0 * ln(1 + t/tau0) (x = 2) return G0_scale * G0_dim * tau0 * jnp.log(1.0 + t_safe) def exp_nonzero(_): # [(1+t/tau0)^(2-x) - 1] / (2-x) return ( G0_scale * G0_dim * tau0 * ((jnp.power(1.0 + t_safe, exp) - 1.0) / exp) ) eta_plus = jax.lax.cond( jnp.abs(exp) < 1e-6, exp_near_zero, exp_nonzero, operand=None, ) return eta_plus @staticmethod @jax.jit def _predict_oscillation_jit( omega: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled oscillation prediction: G'(omega), G''(omega). Uses same kernel functions as SGRConventional for linear response. The GENERIC formulation gives equivalent results in the linear regime. Args: omega: Angular frequency array (rad/s) x: Effective noise temperature (dimensionless) G0_scale: Modulus scale (Pa) tau0: Attempt time (s) Returns: Complex modulus [G', G''] with shape (M, 2) """ # Compute dimensionless frequency omega_tau0 = omega * tau0 # Call Gp kernel (returns G_prime, G_double_prime) G_prime, G_double_prime = Gp(x, omega_tau0) # Scale by G0 G_prime_scaled = G0_scale * G_prime G_double_prime_scaled = G0_scale * G_double_prime # Stack into (M, 2) array G_star = jnp.stack([G_prime_scaled, G_double_prime_scaled], axis=1) return G_star @staticmethod @jax.jit def _predict_relaxation_jit( t: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled relaxation prediction: G(t). Uses power-law form consistent with SGR theory. Args: t: Time array (s) x: Effective noise temperature (dimensionless) G0_scale: Modulus scale (Pa) tau0: Attempt time (s) Returns: Relaxation modulus G(t) with shape (M,) """ # Dimensionless time t_scaled = t / tau0 # Compute equilibrium modulus factor (dimensionless) G0_dim = G0(x) epsilon = 1e-12 t_safe = jnp.maximum(t_scaled, epsilon) # Power-law form: G(t) ~ (1 + t/tau0)^(1-x), i.e. 
G(t) ~ t^(1-x) for # 1 < x < 2 — the theoretical SGR relaxation exponent (Fourier- # consistent with SAOS G', G'' ~ omega^(x-1); negative of the creep # exponent x-1). Kept in lockstep with SGRConventional so the # relaxation parity test holds at all x. (Old form used x-2.) G_t = G0_scale * G0_dim / jnp.power(1.0 + t_safe, x - 1.0) return G_t @staticmethod @jax.jit def _predict_viscosity_jit( gamma_dot: jnp.ndarray, x: jax.Array | float, G0_scale: jax.Array | float, tau0: jax.Array | float, ) -> jnp.ndarray: """JIT-compiled viscosity prediction: eta(gamma_dot). Computes viscosity as function of shear rate: eta ~ gamma_dot^(x-2) for 1 < x < 2 (shear-thinning) eta = const for x >= 2 (Newtonian) sigma_y > 0 for x < 1 (yield stress, glass phase) Args: gamma_dot: Shear rate array (1/s) x: Effective noise temperature (dimensionless) G0_scale: Modulus scale (Pa) tau0: Attempt time (s) Returns: Viscosity eta(gamma_dot) with shape (M,) Notes: - Shear-thinning exponent: x - 2 - Uses relationship: eta ~ G0 * tau0 * (gamma_dot * tau0)^(x-2) """ # Dimensionless shear rate gamma_dot_scaled = gamma_dot * tau0 epsilon = 1e-12 gamma_dot_safe = jnp.maximum(gamma_dot_scaled, epsilon) # Compute equilibrium modulus factor G0_dim = G0(x) # Viscosity power-law exponent visc_exp = x - 2.0 # Viscosity formula # eta = G0_scale * tau0 * G0_dim * (gamma_dot * tau0)^(x-2) # For x = 2: eta = const (Newtonian) # For x < 2: eta decreases with gamma_dot (shear-thinning) eta = G0_scale * tau0 * G0_dim * jnp.power(gamma_dot_safe, visc_exp) return eta def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Predict based on fitted test mode. Routes to appropriate prediction method based on stored test_mode. 
Args: X: Independent variable (frequency or time) **kwargs: Additional arguments including optional test_mode override Returns: Predicted values (complex modulus, relaxation modulus, or viscosity) Raises: ValueError: If test_mode not set (model not fitted) """ # Get test_mode from kwargs or instance attribute. # R10-SGR-001: use explicit None check instead of `or` to avoid swallowing # falsy-but-valid test_mode strings (e.g. an empty string, though unlikely). test_mode = kwargs.get("test_mode") if test_mode is None: test_mode = getattr(self, "_test_mode", None) if test_mode is None: raise ValueError("test_mode must be specified for prediction") if test_mode == "oscillation": return self._predict_oscillation(X) elif test_mode == "relaxation": return self._predict_relaxation(X) elif test_mode in ("steady_shear", "flow_curve"): return self._predict_steady_shear(X) elif test_mode == "creep": return self._predict_creep(X) elif test_mode == "startup": return self._predict_startup(X) elif test_mode in ("laos", "oscillation_laos"): # R8-SGR-001: wire LAOS protocol to simulate_laos() gamma_0 = kwargs.get("gamma_0", getattr(self, "_gamma_0", 0.1)) omega = kwargs.get( "omega_laos", kwargs.get("omega", getattr(self, "_omega_laos", 1.0)) ) n_cycles = kwargs.get("n_cycles", getattr(self, "_n_cycles", 2)) n_points_per_cycle = 256 _strain, stress = self.simulate_laos( gamma_0=gamma_0, omega=omega, n_cycles=n_cycles, n_points_per_cycle=n_points_per_cycle, ) # simulate_laos returns (strain, stress), not (time, stress). # Always interpolate to the user's X grid — a length coincidence # does not guarantee the internal and user grids align. 
period = 2.0 * np.pi / omega t_internal = np.linspace( 0, n_cycles * period, n_cycles * n_points_per_cycle, endpoint=False ) X_arr = np.asarray(X) stress_arr = np.asarray(stress) return np.interp(X_arr, t_internal, stress_arr) else: raise ValueError(f"Unknown test_mode: {test_mode}") def _predict_oscillation(self, omega: np.ndarray) -> np.ndarray: """Predict complex modulus in oscillation mode. Args: omega: Angular frequency array (rad/s) Returns: Complex modulus G* = G' + iG'' with shape (M,) """ x = self.parameters.get_value("x") G0_scale = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") omega_jax = jnp.asarray(omega) G_star_jax = self._predict_oscillation_jit(omega_jax, x, G0_scale, tau0) # Convert (N,2) [G', G''] to complex G* for consistent API result = np.array(G_star_jax) return result[:, 0] + 1j * result[:, 1] def _predict_relaxation(self, t: np.ndarray) -> np.ndarray: """Predict relaxation modulus in relaxation mode. Args: t: Time array (s) Returns: Relaxation modulus array (Pa) """ x = self.parameters.get_value("x") G0_scale = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") t_jax = jnp.asarray(t) G_t_jax = self._predict_relaxation_jit(t_jax, x, G0_scale, tau0) return np.array(G_t_jax) def _predict_steady_shear(self, gamma_dot: np.ndarray) -> np.ndarray: """Predict stress in steady shear mode. Args: gamma_dot: Shear rate array (1/s) Returns: Stress array (Pa) — consistent with the fit path (_fit_steady_shear_mode). """ x = self.parameters.get_value("x") G0_scale = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") gamma_dot_jax = jnp.asarray(gamma_dot) # R10-SGR-002: use _predict_steady_shear_jit (stress output) so that fit # and predict are consistent. _predict_viscosity_jit returns eta, not sigma. 
sigma_jax = self._predict_steady_shear_jit(gamma_dot_jax, x, G0_scale, tau0) return np.array(sigma_jax) def _predict_creep(self, t: np.ndarray) -> np.ndarray: """Predict creep compliance J(t).""" x = self.parameters.get_value("x") G0_scale = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") t_jax = jnp.asarray(t) J_t_jax = self._predict_creep_jit(t_jax, x, G0_scale, tau0) return np.array(J_t_jax) def _predict_startup(self, t: np.ndarray) -> np.ndarray: """Predict startup stress growth coefficient eta_plus(t).""" x = self.parameters.get_value("x") G0_scale = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") gamma_dot = getattr(self, "_startup_gamma_dot", None) if gamma_dot is None: raise RuntimeError( "SGRGeneric._predict_startup requires _startup_gamma_dot. " "Call fit() with test_mode='startup' first." ) t_jax = jnp.asarray(t) eta_plus_jax = self._predict_startup_jit(t_jax, x, G0_scale, tau0, gamma_dot) return np.array(eta_plus_jax)
def model_function(self, X, params, test_mode=None, **kwargs):
    """Model function for Bayesian inference with NumPyro NUTS.

    Required by BayesianMixin for NumPyro NUTS sampling. Dispatches to the
    JIT-compiled per-protocol predictor selected by the test mode.

    Args:
        X: Independent variable (frequency or time)
        params: Array of parameter values [x, G0, tau0]
        test_mode: Optional test mode override. When None, the mode stored on
            the instance (``_test_mode``) is used; if that attribute is
            missing or None, ``"oscillation"`` is assumed.
        **kwargs: Protocol-specific arguments (gamma_dot, sigma_applied, etc.)

    Returns:
        Model predictions as JAX array

    Raises:
        RuntimeError: Startup mode without a resolvable ``gamma_dot``.
        NotImplementedError: LAOS modes (not supported for NUTS inference).
        ValueError: Unknown test mode.
    """
    x = params[0]
    G0_scale = params[1]
    tau0 = params[2]

    # getattr (not plain attribute access) so an instance that never stored a
    # test mode falls back to "oscillation" instead of raising AttributeError;
    # this mirrors the lookup used by _predict().
    mode = test_mode if test_mode is not None else getattr(self, "_test_mode", None)
    if mode is None:
        mode = "oscillation"

    X_jax = jnp.asarray(X)

    if mode == "oscillation":
        return self._predict_oscillation_jit(X_jax, x, G0_scale, tau0)
    elif mode == "relaxation":
        return self._predict_relaxation_jit(X_jax, x, G0_scale, tau0)
    elif mode in ("steady_shear", "flow_curve"):
        return self._predict_steady_shear_jit(X_jax, x, G0_scale, tau0)
    elif mode == "creep":
        return self._predict_creep_jit(X_jax, x, G0_scale, tau0)
    elif mode == "startup":
        # Priority: explicit kwarg > _last_fit_kwargs > instance attr
        # Use None sentinel (not `or`) to avoid swallowing gamma_dot=0.0.
        gamma_dot = kwargs.get("gamma_dot")
        if gamma_dot is None:
            last_kwargs = getattr(self, "_last_fit_kwargs", None) or {}
            gamma_dot = last_kwargs.get("gamma_dot")
        if gamma_dot is None:
            gamma_dot = getattr(self, "_startup_gamma_dot", None)
        if gamma_dot is None:
            # R-SGR-GENERIC-001: Require explicit gamma_dot — silent 1.0 default
            # masks bugs during NUTS startup inference.
            raise RuntimeError(
                "SGRGeneric.model_function: gamma_dot not provided and "
                "_startup_gamma_dot not cached. Call fit() with startup data first."
            )
        return self._predict_startup_jit(X_jax, x, G0_scale, tau0, gamma_dot)
    elif mode in ("laos", "oscillation_laos"):
        # R8-SGR-001: LAOS not supported in NUTS (OOM for Bayesian), raise informative error
        raise NotImplementedError(
            "LAOS mode is not supported in model_function for Bayesian inference. "
            "Use _predict() / predict() directly after fitting."
        )
    else:
        raise ValueError(f"Unsupported test mode: {mode}")
def get_phase_regime(self) -> str:
    """Classify the material phase from the noise temperature ``x``.

    Thresholds are checked in ascending order: ``x < 1`` is ``"glass"``,
    ``x < 2`` is ``"power-law"``, and anything else (including values for
    which neither comparison holds) is ``"newtonian"``.

    Returns:
        One of ``"glass"``, ``"power-law"`` or ``"newtonian"``.
    """
    noise_temperature = self.parameters.get_value("x")
    assert noise_temperature is not None
    for upper_limit, regime in ((1.0, "glass"), (2.0, "power-law")):
        if noise_temperature < upper_limit:
            return regime
    return "newtonian"
# =========================================================================
# Dynamic x Parameter Initialization
# =========================================================================

def _init_dynamic_x_parameters(self) -> None:
    """Register the parameters used by dynamic noise-temperature evolution.

    A parameter is added only when no parameter of that name is registered
    yet, so repeated calls never overwrite existing values. Registered
    parameters (name: default, bounds, units):

        - x_eq: 1.0, (0.5, 2.5), dimensionless (noise temperature at rest)
        - alpha_aging: 0.1, (0.0, 10.0), 1/s (aging rate coefficient)
        - beta_rejuv: 0.5, (0.0, 10.0), s (rejuvenation rate coefficient)
        - x_ss_A: 0.5, (0.0, 2.0), dimensionless (steady-state amplitude)
        - x_ss_n: 0.3, (0.0, 1.0), dimensionless (steady-state exponent)
    """
    parameter_specs = (
        ("x_eq", 1.0, (0.5, 2.5), "dimensionless",
         "Equilibrium noise temperature at rest"),
        ("alpha_aging", 0.1, (0.0, 10.0), "1/s",
         "Aging rate coefficient"),
        ("beta_rejuv", 0.5, (0.0, 10.0), "s",
         "Rejuvenation rate coefficient"),
        ("x_ss_A", 0.5, (0.0, 2.0), "dimensionless",
         "Steady-state amplitude factor"),
        ("x_ss_n", 0.3, (0.0, 1.0), "dimensionless",
         "Steady-state power-law exponent"),
    )
    for name, default, bounds, units, description in parameter_specs:
        if name in self.parameters.keys():
            continue  # keep whatever value is already registered
        self.parameters.add(
            name=name,
            value=default,
            bounds=bounds,
            units=units,
            description=description,
        )

# =========================================================================
# Thixotropy Methods (User Story 3)
# =========================================================================
def enable_thixotropy(
    self,
    k_build: float = 0.1,
    k_break: float = 0.5,
    n_struct: float = 2.0,
) -> None:
    """Enable thixotropy modeling with a structural parameter lambda(t).

    Registers (or updates) the thixotropy kinetics parameters. The structural
    parameter lambda describes the internal microstructure:
        - lambda = 1: Fully built structure
        - lambda = 0: Fully broken structure

    Evolution equation:
        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    The effective modulus is coupled to lambda:
        G_eff(t) = G0 * lambda(t)^n_struct

    A parameter that does not exist yet is added with fixed bounds and
    metadata; one that already exists only has its value overwritten.

    Args:
        k_build: Structure build-up rate (1/s), default 0.1
        k_break: Structure breakdown rate (dimensionless), default 0.5
        n_struct: Structural coupling exponent, default 2.0

    Example:
        >>> model = SGRGeneric()
        >>> model.enable_thixotropy(k_build=0.1, k_break=0.5, n_struct=2.0)
        >>> # Now model can predict stress transients with thixotropy
    """
    requested = (
        ("k_build", k_build, (0.0, 10.0), "1/s",
         "Structure build-up rate (1/s)"),
        ("k_break", k_break, (0.0, 10.0), "dimensionless",
         "Structure breakdown rate (shear-dependent)"),
        ("n_struct", n_struct, (0.1, 5.0), "dimensionless",
         "Structural coupling exponent"),
    )
    for name, value, bounds, units, description in requested:
        if name in self.parameters.keys():
            # Already registered: refresh only the value, keep bounds/metadata.
            self.parameters.set_value(name, value)
        else:
            self.parameters.add(
                name=name,
                value=value,
                bounds=bounds,
                units=units,
                description=description,
            )

    # Flag consulted by the thixotropy / friction-matrix code paths.
    self._thixotropy_enabled = True
@staticmethod
@jax.jit
def _evolve_lambda_jit(
    t_jax: jnp.ndarray,
    gamma_dot_abs: jnp.ndarray,
    lambda_initial: float,
    k_build: float,
    k_break: float,
) -> jnp.ndarray:
    """Integrate d(lambda)/dt = k_build*(1 - lambda) - k_break*|gamma_dot|*lambda.

    Written as d(lambda)/dt = A - B*lambda with A = k_build and
    B = k_build + k_break*|gamma_dot|, every interval [t[i-1], t[i]] is
    advanced with the exact exponential solution, using the shear rate
    sampled at t[i]. If B <= 1e-12 the step falls back to lambda + A*dt.
    Each updated value is clipped to [0, 1]; the value at index 0 is
    ``lambda_initial`` as given.

    Args:
        t_jax: Time grid, shape (N,), N >= 1.
        gamma_dot_abs: Shear-rate magnitude on the same grid, shape (N,).
        lambda_initial: Initial structural parameter.
        k_build: Build-up rate.
        k_break: Shear breakdown coefficient.

    Returns:
        lambda on the time grid, shape (N,).
    """
    # Interval lengths; sample 0 only supplies the initial condition.
    step_dt = t_jax[1:] - t_jax[:-1]
    step_rate = gamma_dot_abs[1:]

    def advance(lam, step):
        h, rate = step
        source = k_build
        relax = k_build + k_break * rate
        target = source / jnp.maximum(relax, 1e-30)
        exact = target + (lam - target) * jnp.exp(-relax * h)
        euler = lam + source * h
        lam_next = jnp.clip(jnp.where(relax > 1e-12, exact, euler), 0.0, 1.0)
        return lam_next, lam_next

    _, trajectory = jax.lax.scan(
        advance, jnp.float64(lambda_initial), (step_dt, step_rate)
    )
    head = jnp.array([lambda_initial])
    return jnp.concatenate([head, trajectory])
def evolve_lambda(
    self,
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_initial: float = 1.0,
) -> np.ndarray:
    """Evolve structural parameter lambda(t) for given shear history.

    Integrates the thixotropy kinetics equation:
        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    The magnitude ``|gamma_dot|`` is used, so the sign of the shear rate
    does not matter. Time stepping is delegated to the JIT-compiled
    ``_evolve_lambda_jit`` kernel (JAX lax.scan). The resulting trajectory is
    also stored on ``self._lambda_trajectory``.

    Args:
        t: Time array (s); any array-like is accepted.
        gamma_dot: Shear rate array (1/s), same shape as t.
        lambda_initial: Initial structural parameter [0, 1], default 1.0

    Returns:
        lambda_t: Structural parameter evolution, same shape as t. An empty
        ``t`` yields an empty array.

    Raises:
        ValueError: If thixotropy not enabled or array shapes mismatch

    Example:
        >>> model = SGRGeneric()
        >>> model.enable_thixotropy()
        >>> t = np.linspace(0, 10, 100)
        >>> gamma_dot = np.ones_like(t) * 10.0  # Constant shear
        >>> lambda_t = model.evolve_lambda(t, gamma_dot, lambda_initial=1.0)
    """
    if not getattr(self, "_thixotropy_enabled", False):
        raise ValueError("Thixotropy not enabled. Call enable_thixotropy() first.")

    # Accept lists/tuples: the shape check below and the JIT kernel both need
    # real arrays (a plain list has no ``.shape``).
    t = np.asarray(t)
    gamma_dot = np.asarray(gamma_dot)
    if t.shape != gamma_dot.shape:
        raise ValueError(
            f"Time and shear rate arrays must have same shape: "
            f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}"
        )

    # Get thixotropy parameters
    k_build = self.parameters.get_value("k_build")
    k_break = self.parameters.get_value("k_break")
    assert k_build is not None and k_break is not None

    if t.size == 0:
        # The kernel indexes t[0]; an empty grid simply has an empty trajectory.
        lambda_t = np.empty(t.shape, dtype=float)
        self._lambda_trajectory = lambda_t
        return lambda_t

    t_jax = jnp.asarray(t)
    gamma_dot_abs = jnp.abs(jnp.asarray(gamma_dot))

    # Call JIT-compiled scanner
    lambda_t_jax = self._evolve_lambda_jit(
        t_jax, gamma_dot_abs, lambda_initial, k_build, k_break
    )

    lambda_t = np.asarray(lambda_t_jax)
    self._lambda_trajectory = lambda_t
    return lambda_t
[docs] def predict_thixotropic_stress( self, t: np.ndarray, gamma_dot: np.ndarray, lambda_t: np.ndarray | None = None, lambda_initial: float = 1.0, ) -> np.ndarray: """Predict stress response with thixotropic modulus. The effective modulus is coupled to the structural parameter: G_eff(t) = G0 * lambda(t)^n_struct Args: t: Time array (s) gamma_dot: Shear rate array (1/s) lambda_t: Pre-computed lambda trajectory, or None to compute lambda_initial: Initial lambda if computing [0, 1], default 1.0 Returns: sigma: Stress response (Pa) Example: >>> model = SGRGeneric() >>> model.enable_thixotropy() >>> t = np.linspace(0, 10, 100) >>> gamma_dot = np.ones_like(t) * 10.0 >>> sigma = model.predict_thixotropic_stress(t, gamma_dot) """ if not self._thixotropy_enabled: raise ValueError("Thixotropy not enabled. Call enable_thixotropy() first.") # Compute lambda trajectory if not provided if lambda_t is None: lambda_t = self.evolve_lambda(t, gamma_dot, lambda_initial) # Get parameters G0_val = self.parameters.get_value("G0") tau0 = self.parameters.get_value("tau0") x = self.parameters.get_value("x") n_struct = self.parameters.get_value("n_struct") assert G0_val is not None assert tau0 is not None assert x is not None assert n_struct is not None # Effective modulus from structure G_eff = G0_val * np.power(lambda_t, n_struct) # Viscosity from power-law (SGR-like) gamma_dot_safe = np.maximum(np.abs(gamma_dot), 1e-12) eta_factor = np.power(gamma_dot_safe * tau0, x - 2.0) # Stress = G_eff * gamma_dot * tau0 * eta_factor sigma = G_eff * gamma_dot * tau0 * eta_factor return sigma
[docs] def predict_stress_transient( self, t: np.ndarray, gamma_dot: np.ndarray, lambda_initial: float = 1.0, ) -> tuple[np.ndarray, np.ndarray]: """Predict stress transient (overshoot/undershoot) for shear step protocol. For step-up in shear rate: Initially high stress (intact structure) followed by decay as structure breaks down (overshoot). For step-down in shear rate: Initially low stress (broken structure) followed by increase as structure rebuilds (undershoot). Args: t: Time array (s) gamma_dot: Shear rate array (1/s), can include steps lambda_initial: Initial structural parameter [0, 1] Returns: sigma: Stress response (Pa) lambda_t: Structural parameter evolution Example: >>> model = SGRGeneric() >>> model.enable_thixotropy() >>> t = np.linspace(0, 10, 100) >>> gamma_dot = np.ones_like(t) >>> gamma_dot[t >= 5] = 10.0 # Step up at t=5 >>> sigma, lambda_t = model.predict_stress_transient(t, gamma_dot) """ # Evolve lambda lambda_t = self.evolve_lambda(t, gamma_dot, lambda_initial) # Compute stress sigma = self.predict_thixotropic_stress(t, gamma_dot, lambda_t) return sigma, lambda_t
# ========================================================================= # Shear Banding Detection Methods (User Story 1) # =========================================================================
def detect_shear_banding(
    self,
    gamma_dot: np.ndarray | None = None,
    sigma: np.ndarray | None = None,
    n_points: int = 100,
    gamma_dot_range: tuple[float, float] = (1e-2, 1e2),
) -> tuple[bool, dict | None]:
    """Detect shear banding from constitutive curve.

    Computes the steady-state flow curve and checks for non-monotonicity
    (d sigma / d gamma_dot < 0) which indicates shear banding instability.
    When the flow curve is computed from the model, the instance's
    ``_test_mode`` is switched to ``"steady_shear"`` only for that
    ``predict()`` call and restored afterwards.

    Args:
        gamma_dot: Shear rate array (1/s). If None, uses gamma_dot_range.
        sigma: Stress array (Pa). If None, computes from model.
        n_points: Number of points if computing flow curve
        gamma_dot_range: Range for computing flow curve if gamma_dot is None

    Returns:
        is_banding: True if shear banding detected
        banding_info: Dict with banding region info, or None

    Example:
        >>> model = SGRGeneric()
        >>> model.parameters.set_value("x", 0.8)  # Glass regime
        >>> is_banding, info = model.detect_shear_banding()
    """
    # Import detection function
    from rheojax.transforms.srfs import detect_shear_banding as _detect_banding

    # Compute flow curve if not provided
    if gamma_dot is None:
        gamma_dot = np.logspace(
            np.log10(gamma_dot_range[0]),
            np.log10(gamma_dot_range[1]),
            n_points,
        )

    if sigma is None:
        # predict() routes on self._test_mode, and "steady_shear" returns
        # stress (R10-SGR-002). Switch the mode only for this call: leaving it
        # at "steady_shear" (the previous behaviour) silently changed what a
        # later predict() returned for a model fitted in another mode.
        previous_mode = getattr(self, "_test_mode", None)
        self._test_mode = "steady_shear"
        try:
            sigma = self.predict(gamma_dot)
        finally:
            self._test_mode = previous_mode

    # Detect shear banding
    is_banding, banding_info = _detect_banding(gamma_dot, sigma, warn=True)
    return is_banding, banding_info
def predict_banded_flow(
    self,
    gamma_dot_applied: float,
    gamma_dot: np.ndarray | None = None,
    sigma: np.ndarray | None = None,
    n_points: int = 100,
) -> dict | None:
    """Predict flow in shear banding regime with lever rule.

    When shear banding occurs, the material splits into bands with different
    local shear rates. This method computes the band fractions and the
    composite stress. When the flow curve is computed from the model, the
    instance's ``_test_mode`` is switched to ``"steady_shear"`` only for that
    ``predict()`` call and restored afterwards.

    Args:
        gamma_dot_applied: Applied average shear rate (1/s)
        gamma_dot: Shear rate array for flow curve. If None, computed
            on ``np.logspace(-2, 3, n_points)``.
        sigma: Stress array for flow curve. If None, computed.
        n_points: Number of points if computing flow curve

    Returns:
        coexistence: Dict with band coexistence info, or None

    Example:
        >>> model = SGRGeneric()
        >>> model.parameters.set_value("x", 0.8)
        >>> coex = model.predict_banded_flow(gamma_dot_applied=1.0)
        >>> if coex:
        ...     print(f"Low band: {coex['fraction_low']:.2%}")
        ...     print(f"High band: {coex['fraction_high']:.2%}")
    """
    from rheojax.transforms.srfs import compute_shear_band_coexistence

    # Compute flow curve if not provided
    if gamma_dot is None:
        gamma_dot = np.logspace(-2, 3, n_points)

    if sigma is None:
        # predict() routes on self._test_mode, and "steady_shear" returns
        # stress (R10-SGR-002). Restore the caller's mode afterwards so this
        # query does not change what later predict() calls compute.
        previous_mode = getattr(self, "_test_mode", None)
        self._test_mode = "steady_shear"
        try:
            sigma = self.predict(gamma_dot)
        finally:
            self._test_mode = previous_mode

    # Compute coexistence
    coexistence = compute_shear_band_coexistence(
        gamma_dot, sigma, gamma_dot_applied
    )
    return coexistence
# ========================================================================= # LAOS Analysis Methods (User Story 2) # =========================================================================
def simulate_laos(
    self,
    gamma_0: float,
    omega: float,
    n_cycles: int = 2,
    n_points_per_cycle: int = 256,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate LAOS response for given strain amplitude and frequency.

    Generates time-domain stress response to sinusoidal strain input:
        gamma(t) = gamma_0 * sin(omega * t)

    The linear-viscoelastic stress from the complex modulus at ``omega``,
        sigma(t) = G' * gamma(t) + (G'' / omega) * gamma_dot(t),
    is multiplied by the heuristic softening factor
        1 - 0.1 * (|gamma(t)| / gamma_0)^2,
    which is applied at every amplitude (R10-SGR-003). The time grid has
    ``n_cycles * n_points_per_cycle`` points on [0, n_cycles * 2*pi/omega)
    (endpoint excluded). ``gamma_0`` and ``omega`` are stored on the instance
    as ``_gamma_0`` / ``_omega_laos``.

    Args:
        gamma_0: Strain amplitude (dimensionless)
        omega: Angular frequency (rad/s), must be positive
        n_cycles: Number of oscillation cycles to simulate (>= 1)
        n_points_per_cycle: Number of time points per cycle (>= 1)

    Returns:
        strain: Strain array gamma(t)
        stress: Stress array sigma(t)

    Raises:
        ValueError: If ``omega`` is not positive, or ``n_cycles`` or
            ``n_points_per_cycle`` is smaller than 1.

    Example:
        >>> model = SGRGeneric()
        >>> model.parameters.set_value("x", 1.5)
        >>> strain, stress = model.simulate_laos(gamma_0=0.1, omega=1.0)
    """
    # Validate before mutating instance state. omega == 0 used to surface as a
    # ZeroDivisionError from the period, and a negative omega produced a time
    # grid running backwards.
    if not omega > 0:
        raise ValueError(f"omega must be positive, got {omega}")
    if n_cycles < 1 or n_points_per_cycle < 1:
        raise ValueError(
            f"n_cycles and n_points_per_cycle must be >= 1, got "
            f"n_cycles={n_cycles}, n_points_per_cycle={n_points_per_cycle}"
        )

    # Store LAOS parameters
    self._gamma_0 = gamma_0
    self._omega_laos = omega

    # Get model parameters
    x = self.parameters.get_value("x")
    G0_scale = self.parameters.get_value("G0")
    tau0 = self.parameters.get_value("tau0")

    # Time array
    period = 2.0 * np.pi / omega
    t_max = n_cycles * period
    n_points = n_cycles * n_points_per_cycle
    t = np.linspace(0, t_max, n_points, endpoint=False)

    # Strain: gamma(t) = gamma_0 * sin(omega * t)
    strain = gamma_0 * np.sin(omega * t)

    # Strain rate: gamma_dot(t) = gamma_0 * omega * cos(omega * t)
    strain_rate = gamma_0 * omega * np.cos(omega * t)

    # Get complex modulus at this frequency
    omega_arr = np.array([omega])
    G_star = self._predict_oscillation_jit(
        jnp.asarray(omega_arr), x, G0_scale, tau0
    )
    G_prime = float(G_star[0, 0])
    G_double_prime = float(G_star[0, 1])

    # In linear viscoelastic regime:
    # sigma(t) = G' * gamma_0 * sin(omega*t) + G'' * gamma_0 * cos(omega*t)
    stress = G_prime * strain + (G_double_prime / omega) * strain_rate

    # R10-SGR-003: apply the softening at ALL amplitudes so a non-zero third
    # harmonic is produced even for gamma_0 <= 0.1. The max() guards
    # gamma_0 == 0 (strain is then identically zero).
    softening = 1.0 - 0.1 * (np.abs(strain) / max(gamma_0, 1e-10)) ** 2
    stress = stress * softening

    return strain, stress
def extract_laos_harmonics(
    self,
    stress: np.ndarray,
    n_points_per_cycle: int = 256,
) -> dict:
    """Extract odd Fourier harmonics (n = 1, 3, 5, 7) from a LAOS stress signal.

    The last ``n_points_per_cycle`` samples are treated as one steady-state
    period and transformed with ``np.fft.fft``; harmonic n then sits in FFT
    bin n. For each harmonic the amplitude ``I_n = 2*|X_n|/N`` and phase
    ``phi_n = angle(X_n)`` describe the component

        I_n * cos(n*omega*t + phi_n)

    (NumPy's FFT phase is a cosine phase). For the sine form
    ``I_n * sin(n*omega*t + psi_n)`` use ``psi_n = phi_n + pi/2``.
    Harmonics n >= 3 that do not lie strictly below the window's Nyquist
    frequency are reported with amplitude and phase 0.0.

    Args:
        stress: Stress time series from simulate_laos()
        n_points_per_cycle: Points per oscillation cycle (>= 2)

    Returns:
        Dictionary containing:
            - I_1, I_3, I_5, I_7: Harmonic amplitudes
            - phi_1, phi_3, phi_5, phi_7: Phase angles (cosine convention)
            - I_3_I_1, I_5_I_1, I_7_I_1: Intensities relative to I_1
              (0.0 when I_1 is not positive)

    Raises:
        ValueError: If ``n_points_per_cycle`` < 2 or fewer than 2 samples
            are available (no fundamental bin).

    Example:
        >>> strain, stress = model.simulate_laos(gamma_0=0.5, omega=1.0)
        >>> harmonics = model.extract_laos_harmonics(stress)
        >>> print(f"Third harmonic ratio: {harmonics['I_3_I_1']:.4f}")
    """
    # stress[-0:] would silently return the whole signal, so reject < 2.
    if n_points_per_cycle < 2:
        raise ValueError(
            f"n_points_per_cycle must be >= 2, got {n_points_per_cycle}"
        )

    # Use last complete cycle for steady-state analysis
    stress_cycle = np.asarray(stress)[-n_points_per_cycle:]
    n = len(stress_cycle)
    if n < 2:
        raise ValueError(
            f"At least 2 stress samples are required, got {n}"
        )
    stress_fft = np.fft.fft(stress_cycle)

    # Bin k (1 <= k) is a positive, non-Nyquist frequency iff k <= (n - 1)//2.
    # For even n this excludes the Nyquist bin n/2, where the 2/n amplitude
    # factor would be wrong. The former check ``k < n // 2`` additionally
    # dropped the last valid bin whenever n was odd.
    max_bin = (n - 1) // 2

    harmonics: dict = {}
    for order in (1, 3, 5, 7):
        k = order  # fundamental is bin 1: one full cycle in the window
        if order == 1 or k <= max_bin:
            amplitude = 2.0 * np.abs(stress_fft[k]) / n
            phase = np.angle(stress_fft[k])
        else:
            amplitude, phase = 0.0, 0.0
        harmonics[f"I_{order}"] = amplitude
        harmonics[f"phi_{order}"] = phase

    # Relative intensities
    I_1 = harmonics["I_1"]
    for order in (3, 5, 7):
        harmonics[f"I_{order}_I_1"] = (
            harmonics[f"I_{order}"] / I_1 if I_1 > 0 else 0.0
        )

    return harmonics
def compute_chebyshev_coefficients(
    self,
    strain: np.ndarray,
    stress: np.ndarray,
    gamma_0: float,
    omega: float,
    n_points_per_cycle: int = 256,
) -> dict:
    """Chebyshev decomposition of the last LAOS cycle.

    Decomposes stress into elastic and viscous Chebyshev contributions:
        sigma(gamma, gamma_dot) = sum_n e_n * T_n(gamma/gamma_0)
                                + sum_n v_n * T_n(gamma_dot/gamma_dot_0)

    where T_n are Chebyshev polynomials of the first kind and
    gamma_dot_0 = gamma_0 * omega. Over the last ``n_points_per_cycle``
    samples the coefficients are the projections
    ``e_n = 2 * mean(sigma * T_n(gamma/gamma_0))`` and
    ``v_n = 2 * mean(sigma * T_n(gamma_dot/gamma_dot_0))``. The strain rate
    is estimated with ``np.gradient`` on a uniform spacing of
    ``2*pi / (omega * n_points_per_cycle)``.

    Physical interpretation:
        - e_n: Elastic (in-phase with strain) Chebyshev coefficients
        - v_n: Viscous (out-of-phase with strain) Chebyshev coefficients
        - e_3/e_1 > 0: Strain stiffening
        - e_3/e_1 < 0: Strain softening
        - v_3/v_1 > 0: Shear thickening
        - v_3/v_1 < 0: Shear thinning

    Args:
        strain: Strain array from simulate_laos()
        stress: Stress array from simulate_laos()
        gamma_0: Strain amplitude
        omega: Angular frequency
        n_points_per_cycle: Points per oscillation cycle

    Returns:
        Dictionary containing:
            - e_1, e_3, e_5: Elastic Chebyshev coefficients
            - v_1, v_3, v_5: Viscous Chebyshev coefficients
            - e_3_e_1, e_5_e_1, v_3_v_1, v_5_v_1: Ratios to the first-order
              coefficient (0.0 when |e_1| or |v_1| <= 1e-12)

    Example:
        >>> strain, stress = model.simulate_laos(gamma_0=0.5, omega=1.0)
        >>> chebyshev = model.compute_chebyshev_coefficients(
        ...     strain, stress, gamma_0=0.5, omega=1.0
        ... )
        >>> print(f"Strain stiffening ratio: {chebyshev['e_3_e_1']:.4f}")
    """
    last_cycle = slice(-n_points_per_cycle, None)
    strain_win = strain[last_cycle]
    stress_win = stress[last_cycle]

    # Strain rate on the uniform grid, then both arguments normalised to the
    # Chebyshev domain [-1, 1].
    spacing = 2.0 * np.pi / (omega * n_points_per_cycle)
    rate = np.gradient(strain_win, spacing)
    arguments = {
        "e": strain_win / gamma_0,
        "v": rate / (gamma_0 * omega),
    }

    # Odd Chebyshev polynomials of the first kind.
    chebyshev_t = (
        (1, lambda u: u),
        (3, lambda u: 4 * u**3 - 3 * u),
        (5, lambda u: 16 * u**5 - 20 * u**3 + 5 * u),
    )

    chebyshev: dict = {}
    for kind, arg in arguments.items():
        for order, poly in chebyshev_t:
            chebyshev[f"{kind}_{order}"] = 2.0 * np.mean(stress_win * poly(arg))

    # Normalized coefficients (standard LAOS metrics)
    for kind in arguments:
        base = chebyshev[f"{kind}_1"]
        usable = abs(base) > 1e-12
        for order in (3, 5):
            chebyshev[f"{kind}_{order}_{kind}_1"] = (
                chebyshev[f"{kind}_{order}"] / base if usable else 0.0
            )

    return chebyshev
def get_lissajous_curve(
    self,
    gamma_0: float,
    omega: float,
    n_points: int = 256,
    normalized: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Stress-vs-strain (elastic Lissajous) curve over one LAOS period.

    Two cycles are simulated with ``simulate_laos`` and only the last one is
    returned, as an approximation of the steady state.

    Args:
        gamma_0: Strain amplitude
        omega: Angular frequency (rad/s)
        n_points: Number of points in curve
        normalized: If True, strain is divided by ``gamma_0`` and stress by
            its peak magnitude (stress is left unscaled if that peak is 0).

    Returns:
        strain: Strain array (one period)
        stress: Stress array (one period)

    Example:
        >>> strain, stress = model.get_lissajous_curve(gamma_0=0.1, omega=1.0)
        >>> plt.plot(strain, stress)  # Elastic Lissajous
    """
    strain, stress = self.simulate_laos(
        gamma_0, omega, n_cycles=2, n_points_per_cycle=n_points
    )
    final_period = slice(-n_points, None)
    strain_cycle, stress_cycle = strain[final_period], stress[final_period]

    if not normalized:
        return strain_cycle, stress_cycle

    peak = np.max(np.abs(stress_cycle))
    scaled_stress = stress_cycle / peak if peak > 0 else stress_cycle
    return strain_cycle / gamma_0, scaled_stress
# =========================================================================
# Dynamic x Evolution Methods (User Story 4)
# =========================================================================


def _poisson_bracket_3d(self, state: np.ndarray) -> np.ndarray:
    """Return the 3x3 antisymmetric Poisson operator L(z) for dynamic-x mode.

    Only the (sigma, lambda) pair is coupled reversibly. The row and column
    belonging to the noise temperature x are zero, so x takes no part in the
    reversible dynamics::

        L = [[0,     L_12, 0],
             [-L_12, 0,    0],
             [0,     0,    0]]

    with ``L_12 = G0 * G0(x) * lambda / tau0`` and lambda clipped to
    ``[0.01, 1.0]``.

    Args:
        state: State vector [sigma, lambda, x]. If it has only two entries,
            x is read from the ``x`` parameter instead.

    Returns:
        3x3 antisymmetric Poisson bracket matrix L.
    """
    params = self.parameters
    structure = np.clip(state[1], 0.01, 1.0)
    noise_temp = state[2] if len(state) > 2 else params.get_value("x")
    modulus = params.get_value("G0")
    relax_time = params.get_value("tau0")
    assert modulus is not None
    assert relax_time is not None
    assert noise_temp is not None

    # R10-SGR-005: no float() cast here so the value stays JAX-traceable.
    dimensionless_modulus = G0(noise_temp)

    # Reversible stress/structure coupling strength: G_eff / tau0.
    coupling = modulus * dimensionless_modulus * structure / relax_time

    return np.array(
        [
            [0.0, coupling, 0.0],
            [-coupling, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ]
    )


def _friction_matrix_3d(
    self, state: np.ndarray, gamma_dot: float = 1.0
) -> np.ndarray:
    """Return the 3x3 symmetric friction matrix M(z) for dynamic-x mode.

    The matrix is block diagonal, ``[[M_11, M_12, 0], [M_12, M_22, 0],
    [0, 0, M_33]]``, with M_12 = 0 and:

    - ``M_11 = exp(-1/x) / tau0 * G_eff`` (Arrhenius-like yielding factor),
      where ``G_eff = G0 * G0(x) * lambda``;
    - ``M_22 = exp(-1/x) / tau0 * lambda * (1 - lambda)``, or
      ``k_build * (1 - lambda) + k_break * |gamma_dot| * lambda`` when
      thixotropy is enabled;
    - ``M_33 = alpha_aging + beta_rejuv * |gamma_dot|`` (aging/rejuvenation).

    The block-diagonal layout means positive semi-definiteness reduces to
    the diagonal entries being non-negative (parameter signs are not checked
    here).

    Args:
        state: State vector [sigma, lambda, x]. If it has only two entries,
            x is read from the ``x`` parameter instead.
        gamma_dot: Current shear rate, used by M_33 and the thixotropic M_22.

    Returns:
        3x3 symmetric friction matrix M.
    """
    params = self.parameters
    structure = np.clip(state[1], 0.01, 1.0)
    noise_temp = state[2] if len(state) > 2 else params.get_value("x")
    modulus = params.get_value("G0")
    relax_time = params.get_value("tau0")
    assert modulus is not None
    assert relax_time is not None
    assert noise_temp is not None

    # R10-SGR-005: no float() cast here so the value stays JAX-traceable.
    dimensionless_modulus = G0(noise_temp)

    effective_modulus = modulus * dimensionless_modulus * structure
    relax_rate = 1.0 / relax_time
    activation = np.exp(-1.0 / noise_temp)  # Arrhenius-like yielding factor

    # (sigma, lambda) block; off-diagonal coupling is deliberately zero.
    m11 = activation * relax_rate * effective_modulus
    m12 = 0.0
    if self._thixotropy_enabled:
        k_build = params.get_value("k_build")
        k_break = params.get_value("k_break")
        m22 = k_build * (1.0 - structure) + k_break * np.abs(gamma_dot) * structure
    else:
        m22 = activation * relax_rate * structure * (1.0 - structure)

    # x block: aging at rest plus shear-driven rejuvenation.
    alpha_aging = params.get_value("alpha_aging")
    beta_rejuv = params.get_value("beta_rejuv")
    m33 = alpha_aging + beta_rejuv * np.abs(gamma_dot)

    return np.array(
        [
            [m11, m12, 0.0],
            [m12, m22, 0.0],
            [0.0, 0.0, m33],
        ]
    )
[docs] def evolve_x( self, t: np.ndarray, gamma_dot: np.ndarray, x0: float | None = None, ) -> np.ndarray: """Evolve noise temperature x(t) for aging/rejuvenation dynamics. Evolution equation:: dx/dt = alpha_aging * (x_eq - x) + beta_rejuv * abs(gamma_dot) * (x_ss - x) where ``x_ss = x_eq + x_ss_A * abs(gamma_dot)^x_ss_n`` is the steady-state value under shear. Args: t: Time array (s) gamma_dot: Shear rate array (1/s), same shape as t x0: Initial noise temperature. If None, uses current x value. Returns: x_t: Noise temperature evolution, same shape as t Example: >>> model = SGRGeneric(dynamic_x=True) >>> t = np.linspace(0, 100, 1000) >>> gamma_dot = np.where(t < 50, 10.0, 0.0) # Shear then rest >>> x_t = model.evolve_x(t, gamma_dot, x0=1.0) """ if not self._dynamic_x: raise ValueError( "Dynamic x not enabled. Create model with SGRGeneric(dynamic_x=True)." ) if t.shape != gamma_dot.shape: raise ValueError( f"Time and shear rate arrays must have same shape: " f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}" ) # Get parameters (raise ValueError instead of bare assert per P2-4) x_eq = self.parameters.get_value("x_eq") alpha_aging = self.parameters.get_value("alpha_aging") beta_rejuv = self.parameters.get_value("beta_rejuv") x_ss_A = self.parameters.get_value("x_ss_A") x_ss_n = self.parameters.get_value("x_ss_n") for _name, _val in [ ("x_eq", x_eq), ("alpha_aging", alpha_aging), ("beta_rejuv", beta_rejuv), ("x_ss_A", x_ss_A), ("x_ss_n", x_ss_n), ]: if _val is None: raise ValueError( f"Parameter '{_name}' is None — set it before calling evolve_x()." ) # Narrow types for mypy after None-check loop above assert x_eq is not None and alpha_aging is not None assert beta_rejuv is not None and x_ss_A is not None and x_ss_n is not None if x0 is None: x0 = self.parameters.get_value("x") if x0 is None: raise ValueError( "Initial x0 is None — provide x0 or set the 'x' parameter." 
) # Integrate using Euler method dt = np.diff(t) dt = np.concatenate([[0], dt]) x_t = np.zeros_like(t) x_t[0] = x0 for i in range(1, len(t)): gamma_dot_abs = np.abs(gamma_dot[i]) # Steady-state x under shear x_ss = x_eq + x_ss_A * np.power(gamma_dot_abs + 1e-12, x_ss_n) # Evolution: aging toward x_eq at rest, rejuvenation toward x_ss under shear dx_dt = alpha_aging * (x_eq - x_t[i - 1]) + beta_rejuv * gamma_dot_abs * ( x_ss - x_t[i - 1] ) x_t[i] = x_t[i - 1] + dx_dt * dt[i] # Clamp to physical range x_t[i] = np.clip(x_t[i], 0.5, 3.0) return x_t
def free_energy_gradient(self, state: np.ndarray) -> np.ndarray:
    """Return the gradient dF/dz of the free energy F = U - x * S.

    With ``G_eff = G0 * G0(x) * lambda`` (regularised by +1e-20 in the
    denominators) the components are:

    - ``dF/d(sigma) = sigma / G_eff`` (strain-like conjugate of stress);
    - ``dF/d(lambda) = -sigma^2 / (2 G_eff^2) * G0 * G0(x) - x * ln((1-lambda)/lambda)``,
      consistent with ``U = sigma^2 / (2 G_eff)`` and the mixing entropy
      ``S = -[lambda ln(lambda) + (1-lambda) ln(1-lambda)]``;
    - ``dF/d(x) = -S``, included only in dynamic-x mode.

    lambda is clipped to ``[0.01, 1 - 1e-10]`` before use.

    Args:
        state: State vector [sigma, lambda] or [sigma, lambda, x]. The third
            entry is used as x only when dynamic x is enabled; otherwise x is
            read from the ``x`` parameter.

    Returns:
        ``[dF/d(sigma), dF/d(lambda)]``, or
        ``[dF/d(sigma), dF/d(lambda), dF/d(x)]`` in dynamic-x mode.
    """
    track_x = len(state) > 2 and self._dynamic_x

    sigma = state[0]
    structure = np.clip(state[1], 0.01, 1.0 - 1e-10)
    noise_temp = state[2] if track_x else self.parameters.get_value("x")

    modulus = self.parameters.get_value("G0")
    assert modulus is not None
    assert noise_temp is not None

    # R10-SGR-005: no float() cast here so the value stays JAX-traceable.
    g0_dim = G0(noise_temp)
    # Effective modulus, regularised against division by zero.
    g_reg = modulus * g0_dim * structure + 1e-20

    # Energetic parts: dU/d(sigma) and dU/d(lambda).
    grad_sigma = sigma / g_reg
    energy_lam = -(sigma**2) / (2.0 * g_reg**2) * modulus * g0_dim

    # Entropic part: dS/d(lambda) = ln((1 - lambda) / lambda).
    entropy_slope = np.log((1.0 - structure) / structure)
    grad_lam = energy_lam - noise_temp * entropy_slope

    if not track_x:
        return np.array([grad_sigma, grad_lam])

    mixing_entropy = -(
        structure * np.log(structure) + (1.0 - structure) * np.log(1.0 - structure)
    )
    return np.array([grad_sigma, grad_lam, -mixing_entropy])
def compute_entropy_production(self, state: np.ndarray) -> float:
    """Return the entropy production rate W at ``state``.

    W is the quadratic form ``(dF/dz)^T M(z) (dF/dz)``. The 3x3 friction
    matrix (at its default shear rate) is used when the state carries a
    dynamic x; otherwise the 2x2 friction matrix is used. The second law
    requires W >= 0.

    Args:
        state: State vector [sigma, lambda] or [sigma, lambda, x].

    Returns:
        W clipped below at 0.0. Values under -1e-12 are reported through
        ``logger.warning`` before clipping. A NaN W is returned unchanged and
        is not logged.
    """
    use_3d = len(state) > 2 and self._dynamic_x
    build_friction = self._friction_matrix_3d if use_3d else self.friction_matrix
    friction = build_friction(state)
    grad = self.free_energy_gradient(state)

    W = grad @ friction @ grad

    # Small negative values are treated as round-off; larger ones are flagged.
    if W < -1e-12:
        logger.warning(
            f"Entropy production W = {W:.6e} < 0 at state={state}. "
            "This violates the second law and may indicate numerical issues."
        )
    return max(W, 0.0)
def verify_thermodynamic_consistency(
    self, state: np.ndarray, tol: float = 1e-10
) -> dict:
    """Verify all GENERIC thermodynamic consistency conditions.

    Checks:
        1. Poisson bracket antisymmetry: L = -L^T
        2. Friction matrix symmetry: M = M^T
        3. Friction matrix positive semi-definiteness: the eigenvalues of the
           symmetric part ``(M + M^T) / 2`` are >= -tol
        4. Entropy production non-negativity: the raw quadratic form
           ``W = (dF/dz)^T M (dF/dz)`` is >= -tol

    W is evaluated here directly from the same M used in checks 2-3.
    ``compute_entropy_production()`` is not used because it clamps negative
    values to 0, which would make check 4 impossible to fail.

    Args:
        state: State vector [sigma, lambda] or [sigma, lambda, x]
        tol: Numerical tolerance for consistency checks

    Returns:
        Dictionary with boolean flags ``poisson_antisymmetric``,
        ``friction_symmetric``, ``friction_positive_semidefinite``,
        ``entropy_production_nonnegative`` and ``thermodynamically_consistent``.
        It also holds the float diagnostics ``poisson_antisymmetry_error``,
        ``friction_symmetry_error``, ``friction_min_eigenvalue`` and
        ``entropy_production`` (unclamped W).
    """
    # Use appropriate operators based on state dimension
    if len(state) > 2 and self._dynamic_x:
        L = np.asarray(self._poisson_bracket_3d(state))
        M = np.asarray(self._friction_matrix_3d(state))
    else:
        L = np.asarray(self.poisson_bracket(state))
        M = np.asarray(self.friction_matrix(state))

    results: dict = {}

    # 1. Poisson bracket antisymmetry
    antisym_error = float(np.max(np.abs(L + L.T)))
    results["poisson_antisymmetric"] = antisym_error < tol
    results["poisson_antisymmetry_error"] = antisym_error

    # 2. Friction matrix symmetry
    sym_error = float(np.max(np.abs(M - M.T)))
    results["friction_symmetric"] = sym_error < tol
    results["friction_symmetry_error"] = sym_error

    # 3. Positive semi-definiteness. eigvalsh reads only one triangle, so on
    # a non-symmetric M it would analyse a different matrix; the symmetric
    # part is what actually determines the sign of z^T M z.
    M_sym = 0.5 * (M + M.T)
    min_eig = float(np.min(np.linalg.eigvalsh(M_sym)))
    results["friction_positive_semidefinite"] = min_eig >= -tol
    results["friction_min_eigenvalue"] = min_eig

    # 4. Entropy production non-negativity, on the raw (unclamped) value
    dF_dz = np.asarray(self.free_energy_gradient(state))
    W = float(dF_dz @ M @ dF_dz)
    results["entropy_production_nonnegative"] = W >= -tol
    results["entropy_production"] = W

    # 5. Overall consistency
    results["thermodynamically_consistent"] = all(
        [
            results["poisson_antisymmetric"],
            results["friction_symmetric"],
            results["friction_positive_semidefinite"],
            results["entropy_production_nonnegative"],
        ]
    )
    return results