"""Cross model for non-Newtonian flow.
This module implements the Cross model, an alternative to the Carreau model
with a different functional form for describing the transition from Newtonian
to power-law behavior (ROTATION test mode).
Theory:
η(γ̇) = η_∞ + (η_0 - η_∞) / [1 + (λγ̇)^m]
- η_0: Zero-shear viscosity (Newtonian plateau at low γ̇)
- η_∞: Infinite-shear viscosity (Newtonian plateau at high γ̇)
- λ: Time constant (characteristic relaxation time)
- m: Rate constant (controls transition steepness)
References:
- Cross, M.M. (1965). J. Colloid Sci. 20, 417-437.
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
import numpy as np
from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger, log_fit
# Module-level logger; Cross._fit emits its debug/info/error records through it.
logger = get_logger(__name__)
@ModelRegistry.register(
    "cross", protocols=[Protocol.FLOW_CURVE], deformation_modes=[DeformationMode.SHEAR]
)
class Cross(BaseModel):
    """Cross model for non-Newtonian flow (ROTATION only).

    The Cross model describes the transition from a Newtonian plateau at
    low shear rates to power-law shear-thinning behavior at high shear
    rates. It uses a different functional form than the Carreau model
    and is often better suited for polymer solutions.

    Parameters:
        eta0: Zero-shear viscosity (Pa·s), Newtonian plateau at low γ̇
        eta_inf: Infinite-shear viscosity (Pa·s), Newtonian plateau at high γ̇
        lambda_: Time constant (s), characteristic relaxation time
        m: Rate constant (dimensionless), controls transition steepness

    Constitutive Equation:
        η(γ̇) = η_∞ + (η_0 - η_∞) / [1 + (λ|γ̇|)^m]

        At λ|γ̇| = 1 the viscosity is exactly the midpoint (η_0 + η_∞) / 2.
        The shear rate enters only through its magnitude |γ̇|.

    Special Cases:
        λ → 0: Newtonian fluid with η = η_0
        m → 0: Newtonian for all γ̇ > 0, with η = (η_0 + η_∞) / 2
        λ → ∞: η approaches η_∞ for any γ̇ ≠ 0

    Test Mode:
        ROTATION (steady shear) only
    """

    # Single source of truth for parameter bounds: used both when declaring
    # the parameters in __init__ and when clipping the heuristic estimates in
    # _fit, so the two can never drift apart.
    _PARAM_BOUNDS = {
        "eta0": (1e-3, 1e12),
        "eta_inf": (1e-6, 1e6),
        "lambda_": (1e-6, 1e6),
        "m": (0.1, 2.0),
    }

    def __init__(self):
        """Initialize the Cross model with default parameter values and bounds."""
        super().__init__()
        bounds = self._PARAM_BOUNDS
        self.parameters = ParameterSet()
        # NOTE: insertion order (eta0, eta_inf, lambda_, m) defines the layout
        # of the params array consumed by model_function.
        self.parameters.add(
            name="eta0",
            value=1000.0,
            bounds=bounds["eta0"],
            units="Pa·s",
            description="Zero-shear viscosity",
        )
        self.parameters.add(
            name="eta_inf",
            value=0.001,
            bounds=bounds["eta_inf"],
            units="Pa·s",
            description="Infinite-shear viscosity",
        )
        self.parameters.add(
            name="lambda_",
            value=1.0,
            bounds=bounds["lambda_"],
            units="s",
            description="Time constant",
        )
        self.parameters.add(
            name="m",
            value=1.0,
            bounds=bounds["m"],
            units="dimensionless",
            description="Rate constant",
        )

    def _cross_params(self) -> tuple[float, float, float, float]:
        """Return the current (eta0, eta_inf, lambda_, m) parameter values."""
        get = self.parameters.get_value
        return get("eta0"), get("eta_inf"), get("lambda_"), get("m")

    def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> Cross:
        """Estimate Cross parameters from flow-curve data with a heuristic.

        This is not a least-squares optimisation; each parameter is read off
        the (shear-rate-sorted) data directly:

        - eta0: maximum viscosity over the first n // 10 + 1 points
        - eta_inf: minimum viscosity over the last ceil(n / 10) points
        - lambda_: 1 / γ̇ at the point whose viscosity is closest to the
          midpoint (eta0 + eta_inf) / 2, because λγ̇ = 1 there
        - m: minus the slope of log(η - eta_inf) vs log(γ̇) over the middle
          third of the points (1.0 when there are too few points)

        Estimates are clipped to the parameter bounds, and eta0 is forced
        above eta_inf, before being stored.

        Args:
            X: Shear rate data (γ̇)
            y: Viscosity data, same shape as X
            **kwargs: Additional fitting options; only ``test_mode`` is read

        Returns:
            self for method chaining

        Raises:
            ValueError: If X and y differ in shape, or no finite
                (shear rate, viscosity) pair remains.
        """
        # P3-FLOW-001: Cache test_mode for Bayesian _resolve_test_mode() consistency
        self._test_mode = kwargs.get("test_mode", "rotation")
        X = np.asarray(X)
        y = np.asarray(y)
        with log_fit(logger, model="Cross", data_shape=X.shape) as ctx:
            try:
                if X.shape != y.shape:
                    raise ValueError(
                        "Cross fit requires X and y of equal shape, "
                        f"got {X.shape} and {y.shape}"
                    )
                # Drop NaN/inf samples: otherwise a single bad point turns
                # every estimate (and so every stored parameter) into NaN.
                finite = np.isfinite(X) & np.isfinite(y)
                if not finite.all():
                    logger.debug(
                        "Dropping non-finite data points",
                        n_dropped=int(finite.size - np.count_nonzero(finite)),
                    )
                    X, y = X[finite], y[finite]
                if X.size == 0:
                    raise ValueError(
                        "Cross fit requires at least one finite "
                        "(shear rate, viscosity) pair"
                    )

                logger.debug(
                    "Starting Cross model fit",
                    n_points=len(X),
                    gamma_dot_range=(float(np.min(X)), float(np.max(X))),
                    viscosity_range=(float(np.min(y)), float(np.max(y))),
                )

                # Sort by shear rate so "start" and "end" mean low and high γ̇.
                order = np.argsort(X)
                X_sorted = X[order]
                y_sorted = y[order]
                n = len(y_sorted)

                # Plateau estimates from both ends of the curve. Head uses
                # n // 10 + 1 points and tail uses ceil(n / 10) points, so both
                # windows hold at least one point.
                eta0_est = np.max(y_sorted[: n // 10 + 1])
                eta_inf_est = np.min(y_sorted[-((n + 9) // 10) :])
                logger.debug(
                    "Estimated viscosity plateaus",
                    eta0_est=float(eta0_est),
                    eta_inf_est=float(eta_inf_est),
                )

                # Cross viscosity equals the plateau midpoint exactly at
                # λγ̇ = 1, so λ ≈ 1 / γ̇ at the point closest to that midpoint.
                eta_mid = (eta0_est + eta_inf_est) / 2.0
                idx_mid = np.argmin(np.abs(y_sorted - eta_mid))
                lambda_est = 1.0 / X_sorted[idx_mid] if X_sorted[idx_mid] > 0 else 1.0
                logger.debug(
                    "Estimated time constant from midpoint",
                    eta_mid=float(eta_mid),
                    lambda_est=float(lambda_est),
                )

                # In the power-law region η - η_∞ ∝ γ̇^(-m), so the negative
                # log-log slope over the middle third approximates m.
                mid_start = n // 3
                mid_end = 2 * n // 3
                if mid_end > mid_start + 1:
                    log_gamma = np.log(np.maximum(X_sorted[mid_start:mid_end], 1e-30))
                    log_eta = np.log(
                        np.maximum(
                            y_sorted[mid_start:mid_end] - eta_inf_est + 1e-10, 1e-30
                        )
                    )
                    slope = np.polyfit(log_gamma, log_eta, 1)[0]
                    m_est = -slope
                else:
                    m_est = 1.0
                logger.debug(
                    "Estimated rate constant from slope",
                    m_est=float(m_est),
                )

                # Clip to the declared parameter bounds.
                bounds = self._PARAM_BOUNDS
                eta0_est = np.clip(eta0_est, *bounds["eta0"])
                eta_inf_est = np.clip(eta_inf_est, *bounds["eta_inf"])
                lambda_est = np.clip(lambda_est, *bounds["lambda_"])
                m_est = np.clip(m_est, *bounds["m"])

                # Ensure eta0 > eta_inf. eta_inf <= 1e6 after clipping, so
                # 10 * eta_inf stays inside the eta0 bounds.
                if eta0_est <= eta_inf_est:
                    eta0_est = eta_inf_est * 10.0
                    logger.debug(
                        "Adjusted eta0 to ensure eta0 > eta_inf",
                        eta0_adjusted=float(eta0_est),
                    )

                estimates = {
                    "eta0": float(eta0_est),
                    "eta_inf": float(eta_inf_est),
                    "lambda_": float(lambda_est),
                    "m": float(m_est),
                }
                for name, value in estimates.items():
                    self.parameters.set_value(name, value)
                    # Add fit results to context for logging
                    ctx[name] = value
                logger.info("Cross model fit completed", **estimates)
                return self
            except Exception as e:
                logger.error(
                    "Cross model fit failed",
                    error=str(e),
                    exc_info=True,
                )
                raise

    def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Predict viscosity for given shear rates.

        Args:
            X: Shear rate data (γ̇)
            **kwargs: Ignored

        Returns:
            Predicted viscosity η(γ̇) as a numpy array
        """
        eta0, eta_inf, lambda_, m = self._cross_params()
        viscosity = self._predict_viscosity(jnp.array(X), eta0, eta_inf, lambda_, m)
        return np.array(viscosity)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (shear rate γ̇)
            params: Array of parameter values [eta0, eta_inf, lambda_, m]
                (the order in which they were added to the ParameterSet)
            test_mode: Ignored; this flow model only supports ROTATION
            **kwargs: Ignored

        Returns:
            Model predictions (viscosity) as JAX array
        """
        eta0, eta_inf, lambda_, m = params[0], params[1], params[2], params[3]
        return self._predict_viscosity(X, eta0, eta_inf, lambda_, m)

    @staticmethod
    @jax.jit
    def _predict_viscosity(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        m: float,
    ) -> jnp.ndarray:
        """Compute viscosity using the Cross model.

        Args:
            gamma_dot: Shear rate (s^-1); only its magnitude is used
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            m: Rate constant

        Returns:
            Viscosity (Pa·s)
        """
        # η(γ̇) = η_∞ + (η_0 - η_∞) / [1 + (λ|γ̇|)^m]
        lambda_gamma = lambda_ * jnp.abs(gamma_dot)
        # Safe power: d/dx x^m = m·x^(m-1) is infinite at x = 0 for m < 1,
        # which makes reverse-mode gradients wrt lambda_ NaN (inf * 0) when the
        # data contain γ̇ = 0. Evaluating the power on a dummy base of 1 at
        # those points and selecting 0 (= 0^m for m > 0, guaranteed by the
        # m bounds) keeps values unchanged and gradients finite.
        is_zero = lambda_gamma == 0
        safe_base = jnp.where(is_zero, 1.0, lambda_gamma)
        power_term = jnp.where(is_zero, 0.0, jnp.power(safe_base, m))
        return eta_inf + (eta0 - eta_inf) / (1.0 + power_term)

    @staticmethod
    @jax.jit
    def _predict_stress(
        gamma_dot: jnp.ndarray,
        eta0: float,
        eta_inf: float,
        lambda_: float,
        m: float,
    ) -> jnp.ndarray:
        """Compute shear-stress magnitude using the Cross model.

        Args:
            gamma_dot: Shear rate (s^-1)
            eta0: Zero-shear viscosity (Pa·s)
            eta_inf: Infinite-shear viscosity (Pa·s)
            lambda_: Time constant (s)
            m: Rate constant

        Returns:
            Shear stress magnitude (Pa), η(γ̇)·|γ̇| (non-negative for
            non-negative viscosity)
        """
        # |σ| = η(γ̇) * |γ̇|
        viscosity = Cross._predict_viscosity(gamma_dot, eta0, eta_inf, lambda_, m)
        return viscosity * jnp.abs(gamma_dot)

    def predict_stress(self, gamma_dot: np.ndarray) -> np.ndarray:
        """Predict shear-stress magnitude for given shear rates.

        Args:
            gamma_dot: Shear rate data (γ̇)

        Returns:
            Predicted shear stress σ(γ̇) = η(γ̇)·|γ̇| as a numpy array
        """
        eta0, eta_inf, lambda_, m = self._cross_params()
        stress = self._predict_stress(jnp.array(gamma_dot), eta0, eta_inf, lambda_, m)
        return np.array(stress)

    def predict_rheo(
        self,
        rheo_data: RheoData,
        test_mode: TestMode | None = None,
        output: str = "viscosity",
    ) -> RheoData:
        """Predict rheological response for RheoData.

        Args:
            rheo_data: Input rheological data; ``x`` holds shear rates
            test_mode: Test mode (must be ROTATION); detected when None
            output: Output type ('viscosity' or 'stress')

        Returns:
            RheoData with predictions; metadata records the model name,
            test mode, output type and the parameter values used

        Raises:
            ValueError: If test mode is not ROTATION or output is unknown
        """
        if test_mode is None:
            test_mode = detect_test_mode(rheo_data)
        if test_mode != TestMode.ROTATION:
            raise ValueError(
                f"Cross model only supports ROTATION test mode, got {test_mode}"
            )

        gamma_dot = rheo_data.x
        eta0, eta_inf, lambda_, m = self._cross_params()
        gamma_dot_jax = jnp.array(gamma_dot)

        if output == "viscosity":
            y_pred = self._predict_viscosity(gamma_dot_jax, eta0, eta_inf, lambda_, m)
            y_units = "Pa·s"
        elif output == "stress":
            y_pred = self._predict_stress(gamma_dot_jax, eta0, eta_inf, lambda_, m)
            y_units = "Pa"
        else:
            raise ValueError(
                f"Invalid output type: {output}. Must be 'viscosity' or 'stress'"
            )

        return RheoData(
            x=np.array(gamma_dot),
            y=np.array(y_pred),
            x_units=rheo_data.x_units or "1/s",
            y_units=y_units,
            domain="time",
            metadata={
                "model": "Cross",
                "test_mode": TestMode.ROTATION,
                "output": output,
                "eta0": eta0,
                "eta_inf": eta_inf,
                "lambda_": lambda_,
                "m": m,
            },
            validate=False,
        )

    def __repr__(self) -> str:
        """String representation showing current parameter values."""
        eta0, eta_inf, lambda_, m = self._cross_params()
        return f"Cross(eta0={eta0:.3e}, eta_inf={eta_inf:.3e}, lambda={lambda_:.3e}, m={m:.3f})"
# Public API of this module: only the Cross model class is exported.
__all__ = ["Cross"]