# Source code for rheojax.models.flow.carreau

"""Carreau model for non-Newtonian flow.

This module implements the Carreau model for fluids with a smooth transition
from Newtonian behavior at low shear rates to power-law (shear-thinning)
behavior at high shear rates (ROTATION test mode).

Theory:
    η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (λγ̇)²]^((n-1)/2)

    - η_0: Zero-shear viscosity (Newtonian plateau at low γ̇)
    - η_∞: Infinite-shear viscosity (Newtonian plateau at high γ̇)
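    - λ: Time constant (characteristic relaxation time); shear thinning sets
      in near γ̇ ≈ 1/λ
    - n: Power-law index; in the shear-thinning region the log-log slope of
      η vs. γ̇ approaches n - 1 (n = 1 gives a Newtonian fluid)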

References:
    - Carreau, P.J. (1972). Rheological equations from molecular network
      theories. Trans. Soc. Rheol. 16, 99-127.
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

import numpy as np

from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger, log_fit

# Module-level logger; the Carreau fitting code emits structured
# (keyword-argument) debug/error records through it.
logger = get_logger(__name__)


@ModelRegistry.register(
    "carreau",
    protocols=[Protocol.FLOW_CURVE],
    deformation_modes=[DeformationMode.SHEAR],
)
class Carreau(BaseModel):
    """Carreau model for non-Newtonian flow (ROTATION only).

    The Carreau model describes the smooth transition from a Newtonian
    plateau at low shear rates (zero-shear viscosity η_0) to power-law
    shear-thinning behavior at high shear rates, with an optional
    infinite-shear viscosity plateau.

    Parameters:
        eta0: Zero-shear viscosity (Pa·s), Newtonian plateau at low γ̇
        eta_inf: Infinite-shear viscosity (Pa·s), Newtonian plateau at high γ̇
        lambda_: Time constant (s); shear thinning sets in near γ̇ ≈ 1/λ
        n: Power-law index (dimensionless); the shear-thinning region has a
            log-log slope of approximately n - 1

    Constitutive Equation:
        η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (``λ`` γ̇)²]^((n-1)/2)

    Special Cases:
        ``λ`` → 0: Newtonian fluid with η = η_0
        ``λ`` → ∞, η_∞ = 0: Power Law behavior
        n = 1: Newtonian fluid for all shear rates

    Test Mode:
        ROTATION (steady shear) only
    """

    # Single source of truth for default values and bounds. __init__ uses
    # them to register the parameters, and _fit uses them to clip the
    # heuristic estimates and to replace NaN estimates.
    _PARAM_DEFAULTS = {
        "eta0": 1000.0,
        "eta_inf": 0.001,
        "lambda_": 1.0,
        "n": 0.5,
    }
    _PARAM_BOUNDS = {
        "eta0": (1e-3, 1e12),
        "eta_inf": (1e-6, 1e6),
        "lambda_": (1e-6, 1e6),
        "n": (0.01, 1.0),
    }

    def __init__(self):
        """Initialize Carreau model with default parameter values and bounds."""
        super().__init__()
        defaults = self._PARAM_DEFAULTS
        bounds = self._PARAM_BOUNDS
        self.parameters = ParameterSet()
        # Order matters: model_function reads params as [eta0, eta_inf, lambda_, n].
        self.parameters.add(
            name="eta0",
            value=defaults["eta0"],
            bounds=bounds["eta0"],
            units="Pa·s",
            description="Zero-shear viscosity",
        )
        self.parameters.add(
            name="eta_inf",
            value=defaults["eta_inf"],
            bounds=bounds["eta_inf"],
            units="Pa·s",
            description="Infinite-shear viscosity",
        )
        self.parameters.add(
            name="lambda_",
            value=defaults["lambda_"],
            bounds=bounds["lambda_"],
            units="s",
            description="Time constant",
        )
        self.parameters.add(
            name="n",
            value=defaults["n"],
            bounds=bounds["n"],
            units="dimensionless",
            description="Power-law index",
        )

    def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> Carreau:
        """Set Carreau parameters from heuristic estimates on flow-curve data.

        No iterative optimization is performed. The parameters are set from
        closed-form estimates (plateau values, the shear rate at the plateau
        midpoint, and a log-log slope), then clipped to the parameter bounds.

        Args:
            X: Shear rate data (gamma_dot), 1-D array-like.
            y: Viscosity data, same length as ``X``.
            **kwargs: Additional fitting options. Only ``test_mode`` is read
                (cached for Bayesian inference); others are ignored.

        Returns:
            self for method chaining

        Raises:
            ValueError: If ``X`` is empty or ``X`` and ``y`` differ in length.
        """
        # P3-FLOW-001: Cache test_mode for Bayesian _resolve_test_mode() consistency
        self._test_mode = kwargs.get("test_mode", "rotation")

        # Accept lists / JAX arrays as well as NumPy arrays (X.shape is used below).
        X = np.asarray(X)
        y = np.asarray(y)
        if X.size == 0:
            raise ValueError("Carreau fit requires at least one data point")
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )

        with log_fit(
            logger,
            self.__class__.__name__,
            data_shape=X.shape,
            test_mode="ROTATION",
        ) as ctx:
            logger.debug(
                "Processing input data",
                gamma_dot_range=(float(X.min()), float(X.max())),
                viscosity_range=(float(y.min()), float(y.max())),
                n_points=len(X),
            )

            # Sort by shear rate so the low/high-shear regions are the array ends.
            sort_idx = np.argsort(X)
            try:
                eta0_est, eta_inf_est, lambda_est, n_est = self._initial_estimates(
                    X[sort_idx], y[sort_idx]
                )
            except Exception as e:
                logger.error(
                    "Parameter estimation failed",
                    error_type=type(e).__name__,
                    error_message=str(e),
                    exc_info=True,
                )
                raise

            # Clip to bounds (NaN estimates fall back to the defaults).
            eta0_est = self._sanitize_estimate("eta0", eta0_est)
            eta_inf_est = self._sanitize_estimate("eta_inf", eta_inf_est)
            lambda_est = self._sanitize_estimate("lambda_", lambda_est)
            n_est = self._sanitize_estimate("n", n_est)

            # The model assumes shear thinning from eta0 down to eta_inf.
            if eta0_est <= eta_inf_est:
                logger.debug(
                    "Adjusting eta0 (was less than or equal to eta_inf)",
                    eta0_original=eta0_est,
                    eta_inf=eta_inf_est,
                )
                eta0_est = eta_inf_est * 10.0

            fitted = {
                "eta0": float(eta0_est),
                "eta_inf": float(eta_inf_est),
                "lambda_": float(lambda_est),
                "n": float(n_est),
            }
            for name, value in fitted.items():
                self.parameters.set_value(name, value)
                # Log fitted parameters
                ctx[name] = value

            logger.debug("Fitting completed successfully", **fitted)

        return self

    @classmethod
    def _initial_estimates(cls, gamma_dot, eta):
        """Heuristic (eta0, eta_inf, lambda_, n) from data sorted by shear rate.

        Args:
            gamma_dot: Shear rates sorted in ascending order (non-empty).
            eta: Viscosities reordered consistently with ``gamma_dot``.

        Returns:
            Unclipped estimates ``(eta0, eta_inf, lambda_, n)``.
        """
        n_points = len(gamma_dot)

        # Plateaus: the *largest* viscosity among the lowest ~10% of shear
        # rates and the *smallest* among the highest ~10% (not averages).
        n_head = n_points // 10 + 1
        n_tail = (n_points + 9) // 10  # ceil(n_points / 10); >= 1 here
        eta0_est = np.max(eta[:n_head])
        eta_inf_est = np.min(eta[-n_tail:])
        logger.debug(
            "Plateau estimates",
            eta0_est=eta0_est,
            eta_inf_est=eta_inf_est,
        )

        # Characteristic time: reciprocal of the shear rate whose viscosity is
        # closest to the midpoint between the two plateau estimates.
        eta_mid = (eta0_est + eta_inf_est) / 2.0
        idx_mid = np.argmin(np.abs(eta - eta_mid))
        gamma_dot_mid = gamma_dot[idx_mid]
        if gamma_dot_mid > 0:
            lambda_est = 1.0 / gamma_dot_mid
        else:
            lambda_est = cls._PARAM_DEFAULTS["lambda_"]
        logger.debug(
            "Characteristic time estimate",
            eta_mid=eta_mid,
            gamma_dot_mid=float(gamma_dot_mid),
            lambda_est=lambda_est,
        )

        # Power-law index: in the shear-thinning region log(eta) vs log(gamma_dot)
        # has slope n - 1. Fit that slope over the middle third of the data.
        mid_start = n_points // 3
        mid_end = 2 * n_points // 3
        if mid_end > mid_start + 1:
            # Floor at 1e-30 so zeros do not produce log(0) = -inf.
            log_gamma = np.log(np.maximum(gamma_dot[mid_start:mid_end], 1e-30))
            log_eta = np.log(np.maximum(eta[mid_start:mid_end], 1e-30))
            slope = np.polyfit(log_gamma, log_eta, 1)[0]
            n_est = slope + 1.0
            logger.debug(
                "Power-law index from regression",
                slope=slope,
                n_est=n_est,
            )
        else:
            n_est = cls._PARAM_DEFAULTS["n"]
            logger.debug(
                "Using default power-law index (insufficient data for regression)",
                n_est=n_est,
            )

        return eta0_est, eta_inf_est, lambda_est, n_est

    @classmethod
    def _sanitize_estimate(cls, name: str, value):
        """Clip one estimate to its bounds, replacing NaN by the default value."""
        # Without this, NaN data would silently produce NaN parameters;
        # +/-inf are left for np.clip, which maps them to the bounds.
        if np.any(np.isnan(value)):
            fallback = cls._PARAM_DEFAULTS[name]
            logger.warning(
                "NaN initial estimate replaced by default",
                parameter=name,
                fallback=fallback,
            )
            value = fallback
        lower, upper = cls._PARAM_BOUNDS[name]
        return np.clip(value, lower, upper)

    def _current_params(self) -> tuple:
        """Return the current ``(eta0, eta_inf, lambda_, n)`` parameter values."""
        get = self.parameters.get_value
        return get("eta0"), get("eta_inf"), get("lambda_"), get("n")

    def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Predict viscosity for given shear rates.

        Args:
            X: Shear rate data (γ̇)
            **kwargs: Unused.

        Returns:
            Predicted viscosity η(γ̇) as a NumPy array
        """
        eta0, eta_inf, lambda_, n = self._current_params()
        viscosity = self._predict_viscosity(jnp.array(X), eta0, eta_inf, lambda_, n)
        return np.array(viscosity)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (shear rate γ̇)
            params: Array of parameter values [eta0, eta_inf, ``lambda_``, n],
                in the order they are added to the ParameterSet
            test_mode: Unused; the flow model only supports ROTATION.
            **kwargs: Unused.

        Returns:
            Predicted viscosity as a JAX array
        """
        eta0 = params[0]
        eta_inf = params[1]
        lambda_ = params[2]
        n = params[3]
        return self._predict_viscosity(X, eta0, eta_inf, lambda_, n)

    @staticmethod
    @jax.jit
    def _predict_viscosity(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        n: float,
    ) -> jnp.ndarray:
        """Compute viscosity using Carreau model.

        Args:
            gamma_dot: Shear rate (s^-1); only its magnitude is used
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            n: Power-law index

        Returns:
            Viscosity (Pa·s)
        """
        # η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (λγ̇)²]^((n-1)/2)
        lambda_gamma = lambda_ * jnp.abs(gamma_dot)
        factor = jnp.power(1.0 + lambda_gamma**2, (n - 1.0) / 2.0)
        return eta_inf + (eta0 - eta_inf) * factor

    @staticmethod
    @jax.jit
    def _predict_stress(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        n: float,
    ) -> jnp.ndarray:
        """Compute shear stress magnitude using Carreau model.

        Args:
            gamma_dot: Shear rate (s^-1)
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            n: Power-law index

        Returns:
            Shear stress (Pa) as |σ| = η(γ̇)·|γ̇| (non-negative for any sign of γ̇)
        """
        viscosity = Carreau._predict_viscosity(gamma_dot, eta0, eta_inf, lambda_, n)
        return viscosity * jnp.abs(gamma_dot)

    def predict_stress(self, gamma_dot: np.ndarray) -> np.ndarray:
        """Predict shear stress for given shear rates.

        Args:
            gamma_dot: Shear rate data (γ̇)

        Returns:
            Predicted shear stress σ(γ̇) (magnitude, via |γ̇|) as a NumPy array
        """
        eta0, eta_inf, lambda_, n = self._current_params()
        stress = self._predict_stress(jnp.array(gamma_dot), eta0, eta_inf, lambda_, n)
        return np.array(stress)

    def predict_rheo(
        self,
        rheo_data: RheoData,
        test_mode: TestMode | None = None,
        output: str = "viscosity",
    ) -> RheoData:
        """Predict rheological response for RheoData.

        Args:
            rheo_data: Input rheological data (``x`` holds shear rates)
            test_mode: Test mode (must be ROTATION); detected from
                ``rheo_data`` when omitted
            output: Output type ('viscosity' or 'stress')

        Returns:
            RheoData with predictions and the parameter values in ``metadata``

        Raises:
            ValueError: If test mode is not ROTATION or ``output`` is invalid
        """
        if test_mode is None:
            test_mode = detect_test_mode(rheo_data)

        if test_mode != TestMode.ROTATION:
            raise ValueError(
                f"Carreau model only supports ROTATION test mode, got {test_mode}"
            )

        gamma_dot = rheo_data.x
        eta0, eta_inf, lambda_, n = self._current_params()
        gamma_dot_jax = jnp.array(gamma_dot)

        if output == "viscosity":
            y_pred = self._predict_viscosity(gamma_dot_jax, eta0, eta_inf, lambda_, n)
            y_units = "Pa·s"
        elif output == "stress":
            y_pred = self._predict_stress(gamma_dot_jax, eta0, eta_inf, lambda_, n)
            y_units = "Pa"
        else:
            raise ValueError(
                f"Invalid output type: {output}. Must be 'viscosity' or 'stress'"
            )

        return RheoData(
            x=np.array(gamma_dot),
            y=np.array(y_pred),
            x_units=rheo_data.x_units or "1/s",
            y_units=y_units,
            domain="time",
            metadata={
                "model": "Carreau",
                "test_mode": TestMode.ROTATION,
                "output": output,
                "eta0": eta0,
                "eta_inf": eta_inf,
                "lambda_": lambda_,
                "n": n,
            },
            validate=False,
        )

    def __repr__(self) -> str:
        """String representation with the current parameter values."""
        eta0, eta_inf, lambda_, n = self._current_params()
        return f"Carreau(eta0={eta0:.3e}, eta_inf={eta_inf:.3e}, lambda={lambda_:.3e}, n={n:.3f})"
# Public API of this module: star-imports expose only the model class.
__all__: list[str] = [
    "Carreau",
]