# Source code for rheojax.models.stz._base

"""Base class for STZ models.

Provides shared infrastructure for:
1. Parameter initialization based on complexity variants
2. JAX-based ODE system definitions (Flow, Transient, LAOS)
3. Integration with Diffrax for time-stepping
"""

from __future__ import annotations

from typing import Literal

import numpy as np

from rheojax.core.base import BaseModel
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.logging import get_logger
from rheojax.models.stz._kernels import stz_ode_rhs

# Safe JAX import: obtain the ``jax`` module and its NumPy-style namespace
# ``jnp`` through the project's jax_config helper instead of importing JAX
# directly (NOTE(review): presumably the helper also applies project-wide JAX
# settings — confirm in rheojax.core.jax_config).
jax, jnp = safe_import_jax()

# Module logger
logger = get_logger(__name__)

# Type definitions
# Allowed model complexity variants. This alias is for static type checking
# only; ``Literal`` does not validate values at runtime.
VariantType = Literal["minimal", "standard", "full"]


class STZBase(BaseModel):
    """Base class for Shear Transformation Zone (STZ) models.

    Implements the core state evolution logic and parameter management for
    different model variants.

    Attributes:
        variant: Model complexity variant ('minimal', 'standard', 'full')
        parameters: ParameterSet containing model constants
    """

    # Variants understood by ``__init__``/``_setup_parameters`` and
    # ``get_initial_state``. Kept explicit because the ``VariantType`` Literal
    # annotation is not enforced at runtime.
    _VARIANTS: tuple[str, ...] = ("minimal", "standard", "full")

    def __init__(self, variant: VariantType = "standard"):
        """Initialize STZ Base Model.

        Args:
            variant: Complexity variant.
                - 'minimal': chi only (2 state vars: stress, chi)
                - 'standard': chi + Lambda (3 state vars) [Default]
                - 'full': chi + Lambda + m (4 state vars)

        Raises:
            ValueError: If ``variant`` is not one of the supported variants.
        """
        # Fail fast: an unknown variant would otherwise silently produce a
        # ParameterSet holding only the common parameters, and the mistake
        # would only surface later (e.g. in get_initial_state()).
        if variant not in self._VARIANTS:
            raise ValueError(
                f"Unknown variant: {variant!r}; expected one of {self._VARIANTS}"
            )
        super().__init__()
        self.variant = variant
        self._setup_parameters()
        # Internal storage for trajectories (for plotting/debugging)
        self._trajectory: dict[str, np.ndarray] | None = None

    def _setup_parameters(self):
        """Initialize ParameterSet based on selected variant.

        Every variant gets the common parameters (G0, sigma_y, chi_inf, tau0,
        epsilon0, c0, ez). 'standard' and 'full' add tau_beta; 'full'
        additionally adds m_inf and rate_m.
        """
        self.parameters = ParameterSet()

        # --- Common Parameters (All Variants) ---

        # G0: Elastic modulus (Pa) - renamed from G_inf to match kernel args
        self.parameters.add(
            name="G0",
            value=1e9,
            bounds=(1e6, 1e12),
            units="Pa",
            description="High-frequency elastic modulus",
        )

        # sigma_y: Yield stress (Pa) - scales the activation barrier
        self.parameters.add(
            name="sigma_y",
            value=1e6,
            bounds=(1e3, 1e9),
            units="Pa",
            description="Yield stress (characteristic stress scale)",
        )

        # chi_inf: Steady-state effective temperature (dimensionless)
        # Represents the structural disorder at high shear rates
        # Metallic glasses: 0.5-2.0; colloids/soft matter: 0.05-0.5
        self.parameters.add(
            name="chi_inf",
            value=0.1,
            bounds=(0.01, 2.0),
            units="dimensionless",
            description="Steady-state effective temperature",
        )

        # tau0: Molecular attempt time (s)
        # Metallic glasses: 1e-13-1e-9; colloids/soft matter: 1e-4-1e-1
        self.parameters.add(
            name="tau0",
            value=1e-12,
            bounds=(1e-14, 1e0),
            units="s",
            description="Molecular vibration timescale",
        )

        # epsilon0: Characteristic strain increment
        self.parameters.add(
            name="epsilon0",
            value=0.1,
            bounds=(0.01, 1.0),
            units="dimensionless",
            description="Characteristic strain increment per STZ event",
        )

        # c0: Specific heat (dimensionless) - controls chi evolution rate
        self.parameters.add(
            name="c0",
            value=1.0,
            bounds=(0.1, 100.0),
            units="dimensionless",
            description="Specific heat parameter (controls chi rate)",
        )

        # Activation energy barrier (dimensionless, scaled by chi)
        # Often fixed to 1.0 in theoretical treatments, but can be fit
        self.parameters.add(
            name="ez",
            value=1.0,
            bounds=(0.1, 5.0),
            units="dimensionless",
            description="STZ formation energy (normalized)",
        )

        # --- Variant Specific Parameters ---

        if self.variant in ("standard", "full"):
            # Lambda dynamics included
            # Relaxation time for Lambda
            self.parameters.add(
                name="tau_beta",
                value=1.0,
                bounds=(0.01, 100.0),
                units="s",
                description="Relaxation timescale for STZ density",
            )

        if self.variant == "full":
            # Back stress / orientation parameters
            self.parameters.add(
                name="m_inf",
                value=0.1,
                bounds=(0.0, 0.5),
                units="dimensionless",
                description="Saturation value for orientational bias",
            )
            self.parameters.add(
                name="rate_m",
                value=1.0,
                bounds=(0.1, 100.0),
                units="dimensionless",
                description="Rate coefficient for orientational bias",
            )

    def get_initial_state(
        self, stress_init: float = 0.0, chi_init: float = 0.05
    ) -> jnp.ndarray:
        """Get initial state vector based on variant.

        Args:
            stress_init: Initial stress value.
            chi_init: Initial effective temperature (Python float). The
                default 0.05 starts the material in an annealed, low-chi
                state.

        Returns:
            Initial state vector y0:
            ``[stress, chi]`` for 'minimal',
            ``[stress, chi, Lambda]`` for 'standard',
            ``[stress, chi, Lambda, m]`` for 'full' (with ``m`` = 0).
            Lambda is initialised at its equilibrium value
            ``exp(-ez / chi_init)``.

        Raises:
            ValueError: If ``self.variant`` is not a known variant.
        """
        if self.variant == "minimal":
            # State: [stress, chi] -- no Lambda needed, so skip computing it.
            return jnp.array([stress_init, chi_init])

        # Lambda_init: Equilibrium at chi_init. Fall back to ez = 1.0 if the
        # parameter has no value set.
        ez_val = self.parameters.get_value("ez")
        ez = ez_val if ez_val is not None else 1.0
        # Avoid division by zero / negative chi in the exponent for
        # non-physical chi_init <= 0 (the state itself keeps chi_init as given).
        safe_chi = max(chi_init, 1e-6)
        lambda_init = jnp.exp(-ez / safe_chi)

        if self.variant == "standard":
            # State: [stress, chi, Lambda]
            return jnp.array([stress_init, chi_init, lambda_init])
        if self.variant == "full":
            # State: [stress, chi, Lambda, m]; no initial orientational bias.
            m_init = 0.0
            return jnp.array([stress_init, chi_init, lambda_init, m_init])
        raise ValueError(f"Unknown variant: {self.variant}")

    def _ode_term_strain_controlled(
        self, t: float, y: jnp.ndarray, args: dict
    ) -> jnp.ndarray:
        """ODE vector field wrapper for Diffrax.

        Delegates to the JAX-compiled kernel stz_ode_rhs.

        Args:
            t: Time
            y: State vector
            args: Dictionary of parameters and inputs

        Returns:
            dy/dt
        """
        return stz_ode_rhs(t, y, args)