# Source code for rheojax.models.classical.zener

"""Zener (Standard Linear Solid) viscoelastic model.

The Zener model, also known as the Standard Linear Solid (SLS), consists of
a Maxwell element (spring G_m and dashpot eta in series) in parallel with
an equilibrium spring G_e.

Theory:
    - Total modulus: G_total = G_e + G_m
    - Relaxation modulus: G(t) = G_e + G_m * exp(-t/tau) where tau = eta/G_m
    - Complex modulus: G*(omega) = G_e + G_m*(omega*tau)^2/(1+(omega*tau)^2) + i*G_m*omega*tau/(1+(omega*tau)^2)
    - Creep compliance: J(t) = 1/(G_e+G_m) + (G_m/(G_e*(G_e+G_m))) * (1 - exp(-t/tau_c))
      where tau_c = eta * (G_e + G_m) / (G_e * G_m)
    - Steady shear viscosity: eta(gamma_dot) = eta (constant)

References:
    - Ferry, J. D. (1980). Viscoelastic properties of polymers.
    - Tschoegl, N. W. (1989). The phenomenological theory of linear viscoelastic behavior.
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()


from rheojax.core.base import BaseModel
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger

# Module-level logger from rheojax.logging.get_logger, named after this
# module via __name__.
logger = get_logger(__name__)


@ModelRegistry.register(
    "zener",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
        Protocol.FLOW_CURVE,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class Zener(BaseModel):
    """Zener (Standard Linear Solid) viscoelastic model.

    The Zener model consists of a Maxwell element (spring G_m and dashpot eta)
    in parallel with an equilibrium spring G_e. This provides both
    instantaneous elastic response and time-dependent relaxation to a finite
    equilibrium modulus.

    Parameters:
        Ge (float): Equilibrium modulus in Pa, range [1e-3, 1e9], default 1e4
        Gm (float): Maxwell modulus in Pa, range [1e-3, 1e9], default 1e5
        eta (float): Viscosity in Pa·s, range [1e-6, 1e12], default 1e3

    Supported test modes:
        - Relaxation: Stress relaxation under constant strain
        - Creep: Strain development under constant stress
        - Oscillation: Small amplitude oscillatory shear (SAOS)
        - Rotation / flow curve: Steady shear flow (sigma = eta * gamma_dot)

    Example:
        >>> from rheojax.models.classical.zener import Zener
        >>> from rheojax.core.data import RheoData
        >>> import jax.numpy as jnp
        >>>
        >>> # Create model
        >>> model = Zener()
        >>> model.parameters.set_value('Ge', 1e4)
        >>> model.parameters.set_value('Gm', 1e5)
        >>> model.parameters.set_value('eta', 1e3)
        >>>
        >>> # Predict relaxation
        >>> t = jnp.linspace(0.01, 10, 100)
        >>> data = RheoData(x=t, y=jnp.zeros_like(t), domain='time')
        >>> G_t = model.predict(data)
    """

    def __init__(self):
        """Initialize Zener model with default parameters."""
        super().__init__()

        # Define parameters with physical bounds. The insertion order
        # (Ge, Gm, eta) is the order of the ``params`` array consumed by
        # ``_fit`` and ``model_function``.
        self.parameters = ParameterSet()
        self.parameters.add(
            name="Ge",
            value=1e4,
            bounds=(1e-3, 1e9),
            units="Pa",
            description="Equilibrium modulus",
        )
        self.parameters.add(
            name="Gm",
            value=1e5,
            bounds=(1e-3, 1e9),
            units="Pa",
            description="Maxwell modulus",
        )
        self.parameters.add(
            name="eta",
            value=1e3,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="Viscosity",
        )

        self.fitted_ = False
        # Store test mode for model_function / _fit's model function
        self._test_mode = TestMode.RELAXATION

    def _dispatch_prediction(self, test_mode, x, Ge, Gm, eta):
        """Route a prediction to the kernel matching ``test_mode``.

        ``_fit``, ``_predict`` and ``model_function`` share this helper, so
        all three accept exactly the same set of test modes.

        Args:
            test_mode: TestMode selecting the physical response to compute
            x: Independent variable (time, angular frequency or shear rate)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Model predictions as a JAX array (complex for oscillation).

        Raises:
            ValueError: If ``test_mode`` is not RELAXATION, CREEP,
                OSCILLATION, ROTATION or FLOW_CURVE.
        """
        if test_mode == TestMode.RELAXATION:
            return self._predict_relaxation(x, Ge, Gm, eta)
        if test_mode == TestMode.CREEP:
            return self._predict_creep(x, Ge, Gm, eta)
        if test_mode == TestMode.OSCILLATION:
            return self._predict_oscillation(x, Ge, Gm, eta)
        if test_mode in (TestMode.ROTATION, TestMode.FLOW_CURVE):
            # Steady-shear (rotation) and flow-curve data share one kernel.
            return self._predict_rotation(x, Ge, Gm, eta)
        raise ValueError(f"Unsupported test mode: {test_mode}")

    def _fit(self, X, y, **kwargs):
        """Fit Zener model to data.

        Args:
            X: RheoData object or independent variable array
            y: Dependent variable array (if X is not RheoData)
            **kwargs: Additional fitting options, forwarded to
                ``_standard_nlsq_fit``

        Returns:
            self for method chaining (the value returned by
            ``_standard_nlsq_fit``)
        """

        def model_fn(x, params):
            """Model function for optimization (stateless)."""
            Ge, Gm, eta = params[0], params[1], params[2]
            # self._test_mode is set by _standard_nlsq_fit before this is
            # called, so it is read at evaluation time rather than captured.
            # FLOW_CURVE is now accepted here, as in _predict/model_function.
            return self._dispatch_prediction(self._test_mode, x, Ge, Gm, eta)

        return self._standard_nlsq_fit(
            X, y, model_fn, default_test_mode=TestMode.RELAXATION, **kwargs
        )

    def _predict(self, X, **kwargs):
        """Predict response based on input data.

        Args:
            X: RheoData object or independent variable array

        Returns:
            Predicted values as JAX array

        Raises:
            ValueError: If the resolved test mode is unsupported.
        """
        if isinstance(X, RheoData):
            # Test mode is inferred from the RheoData object itself
            test_mode = detect_test_mode(X)
            x_data = jnp.array(X.x)
        else:
            x_data = jnp.array(X)
            # Use test_mode from last fit if available, otherwise RELAXATION
            test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)

        Ge = self.parameters.get_value("Ge")
        Gm = self.parameters.get_value("Gm")
        eta = self.parameters.get_value("eta")

        return self._dispatch_prediction(test_mode, x_data, Ge, Gm, eta)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (time, frequency, or shear rate)
            params: Array of parameter values [Ge, Gm, eta]
            test_mode: Test mode for predictions (relaxation, creep,
                oscillation, rotation, flow curve). Defaults to the stored
                test mode, or RELAXATION if none is stored.

        Returns:
            Model predictions as JAX array

        Raises:
            ValueError: If the resolved test mode is unsupported.
        """
        # Extract parameters from array (in order they were added to ParameterSet)
        Ge = params[0]
        Gm = params[1]
        eta = params[2]

        # Use provided test_mode, or fall back to stored test mode or default
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)

        return self._dispatch_prediction(test_mode, X, Ge, Gm, eta)

    @staticmethod
    @jax.jit
    def _predict_relaxation(
        t: jnp.ndarray, Ge: float, Gm: float, eta: float
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t).

        Theory:
            G(t) = G_e + G_m * exp(-t/tau) where tau = eta/G_m

        Args:
            t: Time array (s)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Relaxation modulus G(t) in Pa
        """
        tau = eta / Gm  # Relaxation time
        return Ge + Gm * jnp.exp(-t / tau)

    @staticmethod
    @jax.jit
    def _predict_creep(t: jnp.ndarray, Ge: float, Gm: float, eta: float) -> jnp.ndarray:
        """Predict creep compliance J(t).

        Theory:
            J(t) = 1/(G_e+G_m) + (G_m/(G_e*(G_e+G_m))) * (1 - exp(-t/tau_c))
            where tau_c = eta * (G_e + G_m) / (G_e * G_m) is the creep
            retardation time

        Args:
            t: Time array (s)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Creep compliance J(t) in 1/Pa
        """
        G_total = Ge + Gm
        J_inf = 1.0 / G_total  # Instantaneous compliance
        tau_c = eta * G_total / (Ge * Gm)  # Retardation time

        # Compliance rises from 1/(Ge+Gm) at t=0 toward 1/Ge as t -> inf
        return J_inf + (Gm / (Ge * G_total)) * (1.0 - jnp.exp(-t / tau_c))

    @staticmethod
    @jax.jit
    def _predict_oscillation(
        omega: jnp.ndarray, Ge: float, Gm: float, eta: float
    ) -> jnp.ndarray:
        """Predict complex modulus G*(omega).

        Theory:
            G'(omega) = G_e + G_m * (omega*tau)^2 / (1 + (omega*tau)^2)
            G''(omega) = G_m * omega*tau / (1 + (omega*tau)^2)
            G*(omega) = G'(omega) + i*G''(omega)

        Args:
            omega: Angular frequency array (rad/s)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Complex modulus G*(omega) in Pa
        """
        tau = eta / Gm  # Relaxation time
        omega_tau = omega * tau
        omega_tau_sq = omega_tau**2

        # Storage modulus G'
        G_prime = Ge + Gm * omega_tau_sq / (1.0 + omega_tau_sq)

        # Loss modulus G''
        G_double_prime = Gm * omega_tau / (1.0 + omega_tau_sq)

        # Complex modulus
        return G_prime + 1j * G_double_prime

    @staticmethod
    @jax.jit
    def _predict_rotation(
        gamma_dot: jnp.ndarray, Ge: float, Gm: float, eta: float
    ) -> jnp.ndarray:
        """Predict steady shear stress sigma(gamma_dot).

        Theory:
            sigma = eta * gamma_dot (Newtonian flow)

        Args:
            gamma_dot: Shear rate array (1/s)
            Ge: Equilibrium modulus (Pa) - not used but kept for interface consistency
            Gm: Maxwell modulus (Pa) - not used but kept for interface consistency
            eta: Viscosity (Pa·s)

        Returns:
            Shear stress in Pa
        """
        return eta * gamma_dot

    def get_relaxation_time(self) -> float:
        """Get characteristic relaxation time tau = eta/G_m.

        Returns:
            Relaxation time in seconds
        """
        Gm = self.parameters.get_value("Gm")
        eta = self.parameters.get_value("eta")
        return eta / Gm

    def get_retardation_time(self) -> float:
        """Get characteristic retardation time for creep.

        Theory:
            tau_c = eta * (G_e + G_m) / (G_e * G_m)

        Returns:
            Retardation time in seconds
        """
        Ge = self.parameters.get_value("Ge")
        Gm = self.parameters.get_value("Gm")
        eta = self.parameters.get_value("eta")
        return eta * (Ge + Gm) / (Ge * Gm)

    def __repr__(self) -> str:
        """String representation of Zener model."""
        Ge = self.parameters.get_value("Ge")
        Gm = self.parameters.get_value("Gm")
        eta = self.parameters.get_value("eta")
        tau = self.get_relaxation_time()
        return f"Zener(Ge={Ge:.2e} Pa, Gm={Gm:.2e} Pa, eta={eta:.2e} Pa·s, tau={tau:.2e} s)"
# Public API of this module: names exported by a star-import.
__all__ = [
    "Zener",
]