"""Carreau model for non-Newtonian flow.
This module implements the Carreau model for fluids with smooth transition
from Newtonian behavior at low shear rates to power-law behavior at high
shear rates (ROTATION test mode).
Theory:
η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (λγ̇)²]^((n-1)/2)
- η_0: Zero-shear viscosity (Newtonian plateau at low γ̇)
- η_∞: Infinite-shear viscosity (Newtonian plateau at high γ̇)
- λ: Time constant (characteristic relaxation time)
- n: Power-law index (controls transition steepness)
References:
- Carreau, P.J. (1972). Trans. Soc. Rheol. 16, 99-127.
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
import numpy as np
from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger, log_fit
# Module logger; Carreau._fit emits its debug and error records through it.
logger = get_logger(__name__)
@ModelRegistry.register(
    "carreau",
    protocols=[Protocol.FLOW_CURVE],
    deformation_modes=[DeformationMode.SHEAR],
)
class Carreau(BaseModel):
    """Carreau model for non-Newtonian flow (ROTATION only).

    The Carreau model describes the smooth transition from a Newtonian
    plateau at low shear rates (zero-shear viscosity η_0) to power-law
    shear-thinning behavior at high shear rates, with an optional
    infinite-shear viscosity plateau.

    Parameters:
        eta0: Zero-shear viscosity (Pa·s), Newtonian plateau at low γ̇
        eta_inf: Infinite-shear viscosity (Pa·s), Newtonian plateau at high γ̇
        lambda_: Time constant (s), characteristic relaxation time
        n: Power-law index (dimensionless), controls transition steepness

    Constitutive Equation:
        η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (``λ`` γ̇)²]^((n-1)/2)

        Only the magnitude |γ̇| enters the computation.

    Special Cases:
        ``λ`` → 0: Newtonian fluid with η = η_0
        ``λ`` γ̇ ≫ 1 with η_∞ = 0: power-law behavior,
            η ≈ η_0 (``λ`` γ̇)^(n-1)
        n = 1: Newtonian fluid for all shear rates

    Test Mode:
        ROTATION (steady shear) only
    """

    def __init__(self):
        """Initialize Carreau model with default values and bounds.

        Parameters are added in the order ``eta0, eta_inf, lambda_, n``.
        :meth:`model_function` indexes its parameter array in that order.
        """
        super().__init__()
        self.parameters = ParameterSet()
        self.parameters.add(
            name="eta0",
            value=1000.0,
            bounds=(1e-3, 1e12),
            units="Pa·s",
            description="Zero-shear viscosity",
        )
        self.parameters.add(
            name="eta_inf",
            value=0.001,
            bounds=(1e-6, 1e6),
            units="Pa·s",
            description="Infinite-shear viscosity",
        )
        self.parameters.add(
            name="lambda_",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="s",
            description="Time constant",
        )
        self.parameters.add(
            name="n",
            value=0.5,
            bounds=(0.01, 1.0),
            units="dimensionless",
            description="Power-law index",
        )

    def _carreau_params(self) -> tuple[float, float, float, float]:
        """Return the current ``(eta0, eta_inf, lambda_, n)`` parameter values."""
        get = self.parameters.get_value
        return get("eta0"), get("eta_inf"), get("lambda_"), get("n")

    def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> Carreau:
        """Estimate Carreau parameters from a flow curve using heuristics.

        No iterative optimisation is performed here. Each parameter is set
        from a simple feature of the data (see the inline comments), then
        clipped into its bounds.

        Args:
            X: Shear-rate data (γ̇). Signs are ignored, matching the |γ̇|
                used by the prediction functions.
            y: Viscosity data, one value per entry of ``X``.
            **kwargs: Additional fitting options. Only ``test_mode`` is read.

        Returns:
            self for method chaining

        Raises:
            ValueError: If ``X`` is empty or ``X`` and ``y`` differ in length.
        """
        # P3-FLOW-001: Cache test_mode for Bayesian _resolve_test_mode() consistency
        self._test_mode = kwargs.get("test_mode", "rotation")

        # Accept lists / JAX arrays as well as NumPy arrays.
        X = np.asarray(X)
        y = np.asarray(y)
        if X.size == 0:
            raise ValueError("Carreau fit requires at least one data point")
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )

        with log_fit(
            logger,
            self.__class__.__name__,
            data_shape=X.shape,
            test_mode="ROTATION",
        ) as ctx:
            logger.debug(
                "Processing input data",
                gamma_dot_range=(float(X.min()), float(X.max())),
                viscosity_range=(float(y.min()), float(y.max())),
                n_points=len(X),
            )

            # Simple heuristic fitting:
            #   eta0:    maximum viscosity among the lowest shear rates
            #   eta_inf: minimum viscosity among the highest shear rates
            #   lambda_: 1 / shear rate where viscosity is closest to the
            #            midpoint between the two plateau estimates
            #   n:       1 + log-log slope of the middle third of the data

            # The model depends only on |γ̇|, so sort by magnitude.
            gamma_abs = np.abs(X)
            sort_idx = np.argsort(gamma_abs)
            X_sorted = gamma_abs[sort_idx]
            y_sorted = y[sort_idx]
            n_points = len(X_sorted)

            # Edge windows of roughly 10% of the data, never empty:
            # floor(n/10) + 1 points at low shear, ceil(n/10) at high shear.
            n_low = n_points // 10 + 1
            n_high = (n_points + 9) // 10

            try:
                # Estimate plateaus
                eta0_est = np.max(y_sorted[:n_low])  # Max over low-shear window
                eta_inf_est = np.min(
                    y_sorted[n_points - n_high :]
                )  # Min over high-shear window
                logger.debug(
                    "Plateau estimates",
                    eta0_est=eta0_est,
                    eta_inf_est=eta_inf_est,
                )

                # Find characteristic shear rate (midpoint)
                eta_mid = (eta0_est + eta_inf_est) / 2.0
                idx_mid = np.argmin(np.abs(y_sorted - eta_mid))
                lambda_est = 1.0 / X_sorted[idx_mid] if X_sorted[idx_mid] > 0 else 1.0
                logger.debug(
                    "Characteristic time estimate",
                    eta_mid=eta_mid,
                    gamma_dot_mid=float(X_sorted[idx_mid]),
                    lambda_est=lambda_est,
                )

                # Estimate n from the power-law region slope, using the
                # middle third of the data. Floor at 1e-30 so zero shear
                # rates / viscosities do not produce log(0).
                mid_start = n_points // 3
                mid_end = 2 * n_points // 3
                if mid_end > mid_start + 1:
                    log_gamma = np.log(np.maximum(X_sorted[mid_start:mid_end], 1e-30))
                    log_eta = np.log(np.maximum(y_sorted[mid_start:mid_end], 1e-30))
                    coeffs = np.polyfit(log_gamma, log_eta, 1)
                    n_est = coeffs[0] + 1.0  # Slope of log η vs log γ̇ is n-1
                    logger.debug(
                        "Power-law index from regression",
                        slope=coeffs[0],
                        n_est=n_est,
                    )
                else:
                    n_est = 0.5
                    logger.debug(
                        "Using default power-law index (insufficient data for regression)",
                        n_est=n_est,
                    )
            except Exception as e:
                logger.error(
                    "Parameter estimation failed",
                    error_type=type(e).__name__,
                    error_message=str(e),
                    exc_info=True,
                )
                raise

            # Clip to the bounds declared in __init__
            eta0_est = np.clip(eta0_est, 1e-3, 1e12)
            eta_inf_est = np.clip(eta_inf_est, 1e-6, 1e6)
            lambda_est = np.clip(lambda_est, 1e-6, 1e6)
            n_est = np.clip(n_est, 0.01, 1.0)

            # Ensure eta0 > eta_inf so the model is shear-thinning
            if eta0_est <= eta_inf_est:
                logger.debug(
                    "Adjusting eta0 (was less than or equal to eta_inf)",
                    eta0_original=eta0_est,
                    eta_inf=eta_inf_est,
                )
                eta0_est = eta_inf_est * 10.0

            self.parameters.set_value("eta0", float(eta0_est))
            self.parameters.set_value("eta_inf", float(eta_inf_est))
            self.parameters.set_value("lambda_", float(lambda_est))
            self.parameters.set_value("n", float(n_est))

            # Log fitted parameters
            ctx["eta0"] = float(eta0_est)
            ctx["eta_inf"] = float(eta_inf_est)
            ctx["lambda_"] = float(lambda_est)
            ctx["n"] = float(n_est)
            logger.debug(
                "Fitting completed successfully",
                eta0=float(eta0_est),
                eta_inf=float(eta_inf_est),
                lambda_=float(lambda_est),
                n=float(n_est),
            )
        return self

    def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Predict viscosity for given shear rates.

        Args:
            X: Shear rate data (γ̇)
            **kwargs: Ignored.

        Returns:
            Predicted viscosity η(γ̇) as a NumPy array
        """
        eta0, eta_inf, lambda_, n = self._carreau_params()
        viscosity = self._predict_viscosity(jnp.array(X), eta0, eta_inf, lambda_, n)
        return np.array(viscosity)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (shear rate γ̇)
            params: Array of parameter values [eta0, eta_inf, ``lambda_``, n],
                in the order they are added in ``__init__``
            test_mode: Ignored; this flow model only supports ROTATION.
            **kwargs: Ignored.

        Returns:
            Model predictions (viscosity) as JAX array
        """
        eta0, eta_inf, lambda_, n = params[0], params[1], params[2], params[3]
        return self._predict_viscosity(X, eta0, eta_inf, lambda_, n)

    @staticmethod
    @jax.jit
    def _predict_viscosity(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        n: float,
    ) -> jnp.ndarray:
        """Compute viscosity using Carreau model.

        Args:
            gamma_dot: Shear rate (s^-1); only its magnitude is used
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            n: Power-law index

        Returns:
            Viscosity (Pa·s)
        """
        # η(γ̇) = η_∞ + (η_0 - η_∞) [1 + (λγ̇)²]^((n-1)/2)
        lambda_gamma = lambda_ * jnp.abs(gamma_dot)
        factor = jnp.power(1.0 + lambda_gamma**2, (n - 1.0) / 2.0)
        return eta_inf + (eta0 - eta_inf) * factor

    @staticmethod
    @jax.jit
    def _predict_stress(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        n: float,
    ) -> jnp.ndarray:
        """Compute shear-stress magnitude using Carreau model.

        Args:
            gamma_dot: Shear rate (s^-1)
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            n: Power-law index

        Returns:
            Shear stress magnitude |σ| = η(γ̇)·|γ̇| (Pa). The result is
            non-negative because the absolute shear rate is used.
        """
        viscosity = Carreau._predict_viscosity(gamma_dot, eta0, eta_inf, lambda_, n)
        return viscosity * jnp.abs(gamma_dot)

    def predict_stress(self, gamma_dot: np.ndarray) -> np.ndarray:
        """Predict shear stress for given shear rates.

        Args:
            gamma_dot: Shear rate data (γ̇)

        Returns:
            Predicted shear stress magnitude η(γ̇)·|γ̇| as a NumPy array
        """
        eta0, eta_inf, lambda_, n = self._carreau_params()
        stress = self._predict_stress(jnp.array(gamma_dot), eta0, eta_inf, lambda_, n)
        return np.array(stress)

    def predict_rheo(
        self,
        rheo_data: RheoData,
        test_mode: TestMode | None = None,
        output: str = "viscosity",
    ) -> RheoData:
        """Predict rheological response for RheoData.

        Args:
            rheo_data: Input rheological data (x holds shear rates)
            test_mode: Test mode (must be ROTATION); detected when None
            output: Output type ('viscosity' or 'stress')

        Returns:
            RheoData with predictions and the current parameters in metadata

        Raises:
            ValueError: If test mode is not ROTATION or output is unknown
        """
        if test_mode is None:
            test_mode = detect_test_mode(rheo_data)

        if test_mode != TestMode.ROTATION:
            raise ValueError(
                f"Carreau model only supports ROTATION test mode, got {test_mode}"
            )

        gamma_dot = rheo_data.x
        eta0, eta_inf, lambda_, n = self._carreau_params()
        gamma_dot_jax = jnp.array(gamma_dot)

        if output == "viscosity":
            y_pred = self._predict_viscosity(gamma_dot_jax, eta0, eta_inf, lambda_, n)
            y_units = "Pa·s"
        elif output == "stress":
            y_pred = self._predict_stress(gamma_dot_jax, eta0, eta_inf, lambda_, n)
            y_units = "Pa"
        else:
            raise ValueError(
                f"Invalid output type: {output}. Must be 'viscosity' or 'stress'"
            )

        return RheoData(
            x=np.array(gamma_dot),
            y=np.array(y_pred),
            x_units=rheo_data.x_units or "1/s",
            y_units=y_units,
            # NOTE(review): x is a shear rate, not time; confirm "time" is the
            # intended RheoData domain for flow curves.
            domain="time",
            metadata={
                "model": "Carreau",
                "test_mode": TestMode.ROTATION,
                "output": output,
                "eta0": eta0,
                "eta_inf": eta_inf,
                "lambda_": lambda_,
                "n": n,
            },
            validate=False,
        )

    def __repr__(self) -> str:
        """String representation with current parameter values."""
        eta0, eta_inf, lambda_, n = self._carreau_params()
        return f"Carreau(eta0={eta0:.3e}, eta_inf={eta_inf:.3e}, lambda={lambda_:.3e}, n={n:.3f})"
# Public API of this module: only the Carreau model class is exported.
__all__ = ["Carreau"]