"""Specialized pipeline classes for common rheological workflows.
This module provides pre-configured pipelines for standard analysis workflows
like mastercurve construction, model comparison, and data conversion.
Example:
>>> from rheojax.pipeline.workflows import ModelComparisonPipeline
>>> pipeline = ModelComparisonPipeline(['maxwell', 'kelvin_voigt', 'zener'])
>>> pipeline.run(data)
>>> best = pipeline.get_best_model()
"""
from __future__ import annotations
import time
import warnings
from typing import TYPE_CHECKING, Any
import numpy as np
from rheojax.core.data import RheoData
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.logging import get_logger
from rheojax.pipeline.base import Pipeline
if TYPE_CHECKING:
from rheojax.models.spp.spp_yield_stress import SPPYieldStress
# Import JAX through the project helper instead of importing it directly
# (the helper enforces float64, per the project's own note on it).
jax, jnp = safe_import_jax()
# Module-level logger shared by every pipeline class in this file; it is
# called with keyword arguments (structured fields) throughout.
logger = get_logger(__name__)
class MastercurvePipeline(Pipeline):
    """Pipeline for time-temperature superposition analysis.

    Loads one dataset per temperature, concatenates them into a single
    ``RheoData`` and multiplies the x values of every temperature group by a
    horizontal shift factor obtained from a simplified WLF equation.

    Attributes:
        reference_temp: Reference temperature for the mastercurve.
        shift_factors: Dictionary of temperature -> shift factor. Rebuilt from
            scratch by every call to :meth:`run`.

    Example:
        >>> pipeline = MastercurvePipeline(reference_temp=298.15)
        >>> pipeline.run(file_paths, temperatures)
        >>> mastercurve = pipeline.get_result()
    """

    # Typical WLF constants used by the simplified shift model:
    # log10(a_T) = -C1 (T - Tref) / (C2 + T - Tref)
    _WLF_C1 = 17.44
    _WLF_C2 = 51.6
    # Upper bound on the number of loader threads used for parallel I/O.
    _MAX_IO_WORKERS = 8

    def __init__(self, reference_temp: float = 298.15):
        """Initialize mastercurve pipeline.

        Args:
            reference_temp: Reference temperature in Kelvin (default: 298.15 K)
        """
        super().__init__()
        self.reference_temp = reference_temp
        self.shift_factors: dict[float, float] = {}

    def run(
        self,
        file_paths: list[str],
        temperatures: list[float],
        format: str = "auto",
        parallel_io: bool = True,
        **load_kwargs,
    ) -> MastercurvePipeline:
        """Execute mastercurve workflow.

        Args:
            file_paths: List of data file paths (one per temperature)
            temperatures: List of temperatures (in Kelvin)
            format: File format for loading
            parallel_io: Whether to load files in parallel (default True).
                Only takes effect when more than one file is given.
            **load_kwargs: Additional arguments passed to load (e.g., x_col, y_col)

        Returns:
            self for method chaining

        Raises:
            ValueError: If file_paths and temperatures have different lengths,
                or if no files are given at all.
        """
        if len(file_paths) != len(temperatures):
            raise ValueError(
                f"Number of files ({len(file_paths)}) must match "
                f"number of temperatures ({len(temperatures)})"
            )
        if len(file_paths) == 0:
            # Fail early with a clear message instead of an opaque
            # "need at least one array to concatenate" from numpy later on.
            raise ValueError("At least one data file is required to build a mastercurve")

        logger.info(
            "Starting mastercurve construction",
            n_datasets=len(file_paths),
            reference_temp=self.reference_temp,
        )
        start_time = time.perf_counter()

        # Shift factors from an earlier run must not leak into this one.
        self.shift_factors = {}

        # Load all datasets (optionally in parallel)
        if parallel_io and len(file_paths) > 1:
            datasets = self._load_datasets_parallel(
                file_paths, format=format, **load_kwargs
            )
        else:
            datasets = self._load_datasets_sequential(
                file_paths, temperatures, format=format, **load_kwargs
            )

        # Merge datasets with temperature metadata, then shift in place.
        self.data = self._merge_datasets(datasets, temperatures)
        self._apply_mastercurve_shift()
        self.history.append(
            ("mastercurve", str(len(file_paths)), str(self.reference_temp))
        )

        total_time = time.perf_counter() - start_time
        logger.info(
            "Mastercurve construction complete",
            n_datasets=len(file_paths),
            n_shift_factors=len(self.shift_factors),
            total_time=total_time,
        )
        return self

    @staticmethod
    def _load_single_dataset(
        file_path: str, format: str, load_kwargs: dict[str, Any]
    ) -> RheoData:
        """Load one file through a throwaway ``Pipeline`` and return its data."""
        loader = Pipeline()
        loader.load(file_path, format=format, **load_kwargs)
        return loader.get_result()

    def _load_datasets_sequential(
        self,
        file_paths: list[str],
        temperatures: list[float],
        format: str = "auto",
        **load_kwargs,
    ) -> list[RheoData]:
        """Load datasets one after another, in the order of ``file_paths``.

        Any loading error is logged with its traceback and re-raised.
        """
        datasets = []
        for i, file_path in enumerate(file_paths):
            dataset_start = time.perf_counter()
            try:
                datasets.append(
                    self._load_single_dataset(file_path, format, load_kwargs)
                )
                logger.debug(
                    "Dataset loaded",
                    dataset=i,
                    file_path=file_path,
                    temperature=temperatures[i],
                    elapsed=time.perf_counter() - dataset_start,
                )
            except Exception as e:
                logger.error(
                    "Failed to load dataset",
                    dataset=i,
                    file_path=file_path,
                    error=str(e),
                    exc_info=True,
                )
                raise
        return datasets

    def _load_datasets_parallel(
        self,
        file_paths: list[str],
        format: str = "auto",
        **load_kwargs,
    ) -> list[RheoData]:
        """Load datasets in parallel using threads (I/O-safe).

        Results are collected in submission order, so the returned list lines
        up with ``file_paths``. The first failing load is logged and re-raised.
        """
        from concurrent.futures import ThreadPoolExecutor

        n_workers = max(1, min(len(file_paths), self._MAX_IO_WORKERS))
        datasets = []
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(self._load_single_dataset, fp, format, load_kwargs)
                for fp in file_paths
            ]
            for i, future in enumerate(futures):
                try:
                    datasets.append(future.result())
                    logger.debug(
                        "Dataset loaded (parallel)",
                        dataset=i,
                        file_path=file_paths[i],
                    )
                except Exception as e:
                    logger.error(
                        "Failed to load dataset",
                        dataset=i,
                        file_path=file_paths[i],
                        error=str(e),
                        exc_info=True,
                    )
                    raise
        return datasets

    def _merge_datasets(
        self, datasets: list[RheoData], temperatures: list[float]
    ) -> RheoData:
        """Merge multiple datasets with temperature metadata.

        Args:
            datasets: List of RheoData objects
            temperatures: Corresponding temperatures

        Returns:
            Merged RheoData whose metadata carries one temperature per point
            under ``"temperatures"``.

        Raises:
            ValueError: If ``datasets`` is empty or its length differs from
                ``temperatures``.
        """
        if len(datasets) == 0:
            raise ValueError("Cannot merge an empty list of datasets")

        # Take the domain from the dataset as it was loaded: the tagged copies
        # below are built without ``domain`` and would only report the
        # constructor's default.
        source_domain = datasets[0].domain

        # Add temperature metadata to copies (avoid mutating caller's datasets).
        # Guard against None metadata on programmatic RheoData (R10-WF-002).
        tagged = [
            RheoData(
                x=d.x,
                y=d.y,
                x_units=d.x_units,
                y_units=d.y_units,
                metadata={**(d.metadata or {}), "temperature": temp},
            )
            for d, temp in zip(datasets, temperatures, strict=True)
        ]

        # For simplicity, concatenate all data
        # In practice, this would be more sophisticated
        all_x = np.concatenate([np.array(d.x) for d in tagged])
        all_y = np.concatenate([np.array(d.y) for d in tagged])
        all_temps = np.concatenate(
            [
                np.full(len(d.x), temp)
                for d, temp in zip(tagged, temperatures, strict=True)
            ]
        )
        return RheoData(
            x=all_x,
            y=all_y,
            x_units=tagged[0].x_units,
            y_units=tagged[0].y_units,
            domain=source_domain,
            metadata={
                "type": "mastercurve",
                "reference_temp": self.reference_temp,
                "temperatures": all_temps.tolist(),
            },
            validate=False,
        )

    def _wlf_shift_factor(self, temp) -> float:
        """Return the WLF shift factor a_T for ``temp`` (1.0 at the reference).

        ``temp`` is expected to be a numpy scalar so that an overflowing
        ``10**log_shift`` yields ``inf`` rather than raising.
        """
        if temp == self.reference_temp:
            return 1.0
        denominator = self._WLF_C2 + temp - self.reference_temp
        if abs(denominator) < 1e-10:
            logger.warning(
                "Temperature at WLF singularity; skipping shift (using 1.0)",
                temp=temp,
                reference_temp=self.reference_temp,
                C2=self._WLF_C2,
            )
            return 1.0
        log_shift = -self._WLF_C1 * (temp - self.reference_temp) / denominator
        return float(10**log_shift)

    def _apply_mastercurve_shift(self):
        """Apply horizontal shift to create mastercurve.

        This implements a simplified WLF-based shift.
        In production, this would use the mastercurve transform.

        Does nothing when there is no data or no per-point temperatures in
        the metadata. Otherwise fills ``shift_factors`` and replaces
        ``self.data.x`` with the shifted values.
        """
        if self.data is None:
            return
        temps = np.array((self.data.metadata or {}).get("temperatures", []))
        if len(temps) == 0:
            return

        # PIPE-004: Convert to numpy for in-place assignment (JAX arrays are immutable)
        shifted_x = np.array(self.data.x)
        for unique_temp in np.unique(temps):
            shift = self._wlf_shift_factor(unique_temp)
            self.shift_factors[float(unique_temp)] = shift
            mask = temps == unique_temp
            # R8-PIPE-001: multiply (not divide) by shift factor for TTS
            shifted_x[mask] = shifted_x[mask] * shift
        self.data.x = shifted_x

    def get_shift_factors(self) -> dict[float, float]:
        """Get computed shift factors.

        Returns:
            Dictionary mapping temperature to shift factor (a copy, so callers
            cannot alter the pipeline's own mapping).
        """
        return self.shift_factors.copy()
def _fit_model_in_subprocess(
    model_name: str,
    x_data,
    y_data,
    fit_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Fit a single model in an isolated subprocess.

    Module-level function required for pickling on spawn context.
    Returns a serializable dict of results (no JAX arrays).

    Args:
        model_name: Registry name of the model to create and fit.
        x_data: Independent variable passed to ``fit`` and ``predict``.
        y_data: Target values; complex values and two-column ``[G', G'']``
            arrays are compared by magnitude.
        fit_kwargs: Extra keyword arguments forwarded to ``model.fit``.

    Returns:
        On success, a dict with ``success=True`` plus ``parameters``,
        ``predictions``, ``residuals``, ``rmse``, ``rel_rmse``, ``r_squared``,
        ``n_params``, ``aic`` and ``bic``. If creating, fitting or scoring the
        model raises, ``{"success": False, "error": <message>}``.
    """
    import numpy as np

    from rheojax.core.jax_config import safe_import_jax

    safe_import_jax()
    from rheojax.models import _ensure_all_registered

    _ensure_all_registered()
    from rheojax.core.registry import ModelRegistry

    def _magnitude(values):
        """Reduce complex or two-column [G', G''] data to real magnitudes."""
        arr = np.asarray(values)
        if np.iscomplexobj(arr):
            return np.abs(arr)
        if arr.ndim == 2 and arr.shape[1] == 2:
            return np.sqrt(arr[:, 0] ** 2 + arr[:, 1] ** 2)
        return arr

    try:
        model = ModelRegistry.create(model_name)
        model.fit(x_data, y_data, **fit_kwargs)

        # Reduce target and prediction the same way so their shapes agree
        # for the residual, including the two-column [G', G''] layout.
        y_pred_mag = _magnitude(model.predict(x_data))
        y_mag = _magnitude(y_data)
        residuals = y_mag - y_pred_mag

        ss_res = float(np.sum(residuals**2))
        rmse = float(np.sqrt(np.mean(residuals**2)))
        ss_tot = float(np.sum((y_mag - np.mean(y_mag)) ** 2))
        r_squared = float(1 - ss_res / ss_tot) if ss_tot > 0 else 0.0

        n = len(y_data)
        k = len(model.parameters) if hasattr(model, "parameters") else 0
        if n > 0 and ss_res > 0:
            log_mean_rss = np.log(ss_res / n)
            aic = float(2 * k + n * log_mean_rss)
            bic = float(k * np.log(n) + n * log_mean_rss)
        elif n > 0 and ss_res == 0:
            # Perfect fit
            aic = float("-inf")
            bic = float("-inf")
        else:
            # No data, or a non-finite residual sum (NaN): rank last, not
            # first, matching the sequential comparison path.
            aic = float("inf")
            bic = float("inf")

        mean_abs_y = np.mean(np.abs(y_mag))
        rel_rmse = float(rmse / mean_abs_y) if mean_abs_y > 1e-15 else float("inf")

        return {
            "success": True,
            "parameters": model.get_params(),
            "predictions": np.asarray(y_pred_mag).tolist(),
            "residuals": np.asarray(residuals).tolist(),
            "rmse": rmse,
            "rel_rmse": rel_rmse,
            "r_squared": r_squared,
            "n_params": k,
            "aic": aic,
            "bic": bic,
        }
    except Exception as e:
        # Report the failure as data: the parent process decides how to log it.
        return {"success": False, "error": str(e)}
class ModelComparisonPipeline(Pipeline):
    """Pipeline for comparing multiple models on the same data.

    This pipeline fits multiple models to the same dataset and
    computes comparison metrics (RMSE, R², AIC, etc.).

    Attributes:
        models: List of model names to compare
        results: Dictionary of model_name -> metrics. Rebuilt from scratch by
            every call to :meth:`run`; models that fail to fit are absent.

    Example:
        >>> pipeline = ModelComparisonPipeline(['maxwell', 'zener', 'springpot'])
        >>> pipeline.run(data)
        >>> best = pipeline.get_best_model()
        >>> print(pipeline.get_comparison_table())
    """

    # Metadata keys copied from ``data.metadata`` into ``fit_kwargs`` when the
    # caller did not pass them explicitly.
    _PROPAGATED_METADATA_KEYS = ("test_mode", "deformation_mode", "poisson_ratio")
    # Seconds to wait for each subprocess fit before treating it as failed.
    _PARALLEL_FIT_TIMEOUT = 300

    def __init__(self, models: list[str]):
        """Initialize model comparison pipeline.

        Args:
            models: List of model names to compare
        """
        super().__init__()
        self.models = models
        self.results: dict[str, dict[str, Any]] = {}

    def run(
        self,
        data: RheoData,
        parallel: bool = False,
        n_workers: int | None = None,
        **fit_kwargs,
    ) -> ModelComparisonPipeline:
        """Fit multiple models and compare.

        Args:
            data: RheoData to fit
            parallel: Whether to fit models in parallel subprocesses.
                Each model gets its own process with independent JIT cache.
                Only takes effect when more than one model is configured.
            n_workers: Number of parallel workers (default: auto)
            **fit_kwargs: Additional arguments passed to fit

        Returns:
            self for method chaining

        Note:
            A model that fails to fit is logged, reported through
            ``warnings.warn`` and left out of ``results``; it does not abort
            the comparison. Parallel results carry no ``"model"`` entry.
        """
        self.data = data
        # Start clean: a model that fails now must not keep a result that was
        # produced by an earlier run on different data.
        self.results = {}
        X = np.array(data.x)
        y = np.array(data.y)

        # Auto-propagate test_mode, deformation_mode, poisson_ratio from
        # data.metadata into fit_kwargs (consistent with Pipeline.fit() and
        # BayesianPipeline.fit_nlsq()).
        _meta = getattr(data, "metadata", None)
        if _meta is not None:
            for key in self._PROPAGATED_METADATA_KEYS:
                if key in fit_kwargs:
                    continue
                value = _meta.get(key)
                if value is not None:
                    fit_kwargs[key] = value

        logger.info(
            "Starting model comparison",
            n_models=len(self.models),
            data_shape=X.shape,
        )
        start_time = time.perf_counter()

        if parallel and len(self.models) > 1:
            self._run_parallel(X, y, n_workers=n_workers, **fit_kwargs)
            total_time = time.perf_counter() - start_time
            logger.info(
                "Model comparison complete (parallel)",
                n_models=len(self.models),
                n_successful=len(self.results),
                total_time=total_time,
            )
            return self

        for model_name in self.models:
            self._fit_one_sequential(model_name, X, y, fit_kwargs)

        total_time = time.perf_counter() - start_time
        logger.info(
            "Model comparison complete",
            n_models=len(self.models),
            n_successful=len(self.results),
            total_time=total_time,
        )
        return self

    @staticmethod
    def _as_magnitude(values) -> np.ndarray:
        """Reduce complex or two-column [G', G''] data to real magnitudes."""
        arr = np.asarray(values)
        # Case 1: Complex values (G* = G' + iG")
        if np.iscomplexobj(arr):
            return np.abs(arr)
        # Case 2: 2D array [G', G"] format
        if arr.ndim == 2 and arr.shape[1] == 2:
            return np.sqrt(arr[:, 0] ** 2 + arr[:, 1] ** 2)
        # Case 3: Real values
        return arr

    @staticmethod
    def _r_squared(y_magnitude, residuals) -> float:
        """Coefficient of determination; 0.0 when the target has no variance."""
        ss_res = float(np.sum(residuals**2))
        ss_tot = float(np.sum((y_magnitude - np.mean(y_magnitude)) ** 2))
        return float(1 - ss_res / ss_tot) if ss_tot > 0 else 0.0

    @staticmethod
    def _information_criteria(n: int, k: int, rss) -> tuple[float, float]:
        """Return ``(aic, bic)`` from the residual sum of squares.

        A zero ``rss`` gives ``-inf`` (perfect fit); no data or a non-finite
        ``rss`` gives ``+inf``.
        """
        if n > 0 and rss > 0:
            log_mean_rss = np.log(rss / n)
            return 2 * k + n * log_mean_rss, k * np.log(n) + n * log_mean_rss
        if n > 0 and rss == 0:
            return -np.inf, -np.inf
        return np.inf, np.inf

    def _fit_one_sequential(
        self, model_name: str, X, y, fit_kwargs: dict[str, Any]
    ) -> None:
        """Fit one model in-process and record its metrics in ``results``.

        Any exception is logged, reported as a warning and swallowed so the
        remaining models are still fitted.
        """
        model_start = time.perf_counter()
        try:
            # Create and fit model
            model = ModelRegistry.create(model_name)
            model.fit(X, y, **fit_kwargs)

            # Compare magnitudes (real values): complex residuals would give
            # a wrong RMSE, and target/prediction must share one shape.
            y_pred_magnitude = self._as_magnitude(model.predict(X))
            y_magnitude = self._as_magnitude(y)
            residuals = y_magnitude - y_pred_magnitude
            n_params = len(model.parameters) if hasattr(model, "parameters") else 0

            # Try to use NLSQ result properties (NLSQ 0.6.0 CurveFitResult compatible)
            # Falls back to manual computation if result not available
            nlsq_result = (
                model.get_nlsq_result() if hasattr(model, "get_nlsq_result") else None
            )
            if nlsq_result is not None and nlsq_result.rmse is not None:
                rmse = nlsq_result.rmse
                # Fall back to manual R² when y_data is missing from the result
                if nlsq_result.r_squared is not None:
                    r_squared = nlsq_result.r_squared
                else:
                    r_squared = self._r_squared(y_magnitude, residuals)
                aic = nlsq_result.aic if nlsq_result.aic is not None else np.inf
                bic = nlsq_result.bic if nlsq_result.bic is not None else np.inf
            else:
                # Fallback: calculate metrics manually (avoid model.score()).
                # R10-WF-003: magnitudes keep ss_tot real for complex y.
                rmse = np.sqrt(np.mean(residuals**2))
                r_squared = self._r_squared(y_magnitude, residuals)
                aic, bic = self._information_criteria(
                    len(y), n_params, np.sum(residuals**2)
                )

            # Calculate relative RMSE
            mean_abs_y = np.mean(np.abs(y_magnitude))
            rel_rmse = rmse / mean_abs_y if mean_abs_y > 1e-15 else np.inf

            # Store results
            self.results[model_name] = {
                "model": model,
                "parameters": model.get_params(),
                "predictions": y_pred_magnitude,  # Always real-valued, plottable magnitudes
                "residuals": residuals,
                "rmse": float(rmse),
                "rel_rmse": float(rel_rmse),
                "r_squared": float(r_squared),
                "n_params": n_params,
                "aic": float(aic) if aic is not None else np.inf,
                "bic": float(bic) if bic is not None else np.inf,
            }
            self.history.append(("fit_compare", model_name, str(r_squared)))
            logger.debug(
                "Model fitted",
                model=model_name,
                r_squared=float(r_squared),
                rmse=float(rmse),
                elapsed=time.perf_counter() - model_start,
            )
        except Exception as e:
            logger.error(
                "Failed to fit model",
                model=model_name,
                error=str(e),
                exc_info=True,
            )
            # stacklevel=3: attribute the warning to the caller of run().
            warnings.warn(f"Failed to fit model {model_name}: {e}", stacklevel=3)

    def _run_parallel(self, X, y, n_workers=None, **fit_kwargs) -> None:
        """Run model fits in parallel subprocesses.

        Every model is submitted up front; results are then collected per
        model. A failed or timed-out fit is logged and reported as a warning,
        and that model is left out of ``results``.
        """
        from rheojax.parallel.config import get_default_workers
        from rheojax.parallel.pool import PersistentProcessPool

        n = n_workers or get_default_workers()
        x_np = np.asarray(X, dtype=np.float64)
        y_np = np.asarray(y)

        with PersistentProcessPool(n_workers=n) as pool:
            futures = {
                model_name: pool.submit(
                    _fit_model_in_subprocess,
                    model_name,
                    x_np,
                    y_np,
                    fit_kwargs,
                )
                for model_name in self.models
            }
            for model_name, future in futures.items():
                failure = None
                try:
                    result = future.result(timeout=self._PARALLEL_FIT_TIMEOUT)
                    if result.get("success"):
                        self.results[model_name] = {
                            "parameters": result["parameters"],
                            "predictions": np.array(result["predictions"]),
                            "residuals": np.array(result["residuals"]),
                            "rmse": result["rmse"],
                            "rel_rmse": result["rel_rmse"],
                            "r_squared": result["r_squared"],
                            "n_params": result["n_params"],
                            "aic": result["aic"],
                            "bic": result["bic"],
                        }
                        self.history.append(
                            ("fit_compare", model_name, str(result["r_squared"]))
                        )
                    else:
                        failure = result.get("error", "unknown")
                        logger.error(
                            "Parallel fit failed",
                            model=model_name,
                            error=failure,
                        )
                except Exception as e:
                    failure = str(e)
                    logger.error(
                        "Parallel fit exception",
                        model=model_name,
                        error=failure,
                        exc_info=True,
                    )
                if failure is not None:
                    # Mirror the sequential path, which also warns on failure.
                    warnings.warn(
                        f"Failed to fit model {model_name}: {failure}", stacklevel=3
                    )

    def get_best_model(self, metric: str = "rmse", minimize: bool = True) -> str:
        """Return name of best-fitting model.

        Args:
            metric: Metric to use for comparison ('rmse', 'r_squared', 'aic', 'bic')
            minimize: If True, lower values are better (e.g., RMSE, AIC, BIC).
                Pass False for metrics where higher is better, such as
                'r_squared'.

        Returns:
            Name of best model. A model whose metric is missing, None or NaN
            is ranked worst.

        Raises:
            ValueError: If no model has been fitted yet.

        Example:
            >>> best = pipeline.get_best_model(metric='aic')
        """
        if not self.results:
            raise ValueError("No models fitted. Call run() first.")

        worst = np.inf if minimize else -np.inf

        def _score(item):
            value = item[1].get(metric)
            # NaN compares False against everything, which would make the
            # winner depend on insertion order; rank it worst instead.
            if value is None or value != value:
                return worst
            return value

        pick = min if minimize else max
        return pick(self.results.items(), key=_score)[0]

    def get_comparison_table(self) -> dict[str, dict[str, float]]:
        """Get comparison table of all models.

        Returns:
            Dictionary of model_name -> metrics

        Example:
            >>> table = pipeline.get_comparison_table()
            >>> for model, metrics in table.items():
            ...     print(f"{model}: R²={metrics['r_squared']:.4f}")
        """
        return {
            name: {
                "rmse": result["rmse"],
                "rel_rmse": result["rel_rmse"],
                "r_squared": result["r_squared"],
                "aic": result.get("aic", np.nan),
                "bic": result.get("bic", np.nan),
                "n_params": result["n_params"],
            }
            for name, result in self.results.items()
        }

    def get_model_result(self, model_name: str) -> dict[str, Any]:
        """Get detailed results for a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with parameters and metrics, plus the fitted model
            under ``"model"`` when the fit ran in-process.

        Raises:
            KeyError: If the model has no stored result.

        Example:
            >>> result = pipeline.get_model_result('maxwell')
            >>> params = result['parameters']
        """
        if model_name not in self.results:
            raise KeyError(f"Model {model_name} not in results")
        return self.results[model_name]
class CreepToRelaxationPipeline(Pipeline):
    """Convert creep compliance data to relaxation modulus.

    Both methods currently apply the pointwise approximation G(t) ≈ 1/J(t);
    'exact' falls back to it with a warning.

    Example:
        >>> pipeline = CreepToRelaxationPipeline()
        >>> pipeline.run(creep_data)
        >>> relaxation_data = pipeline.get_result()
    """

    # Unit labels of a compliance mapped to the label of its reciprocal.
    # Labels not listed here are passed through unchanged.
    _INVERSE_UNITS = {
        "1/Pa": "Pa",
        "Pa^-1": "Pa",
        "Pa⁻¹": "Pa",
        "1/kPa": "kPa",
        "kPa^-1": "kPa",
        "1/MPa": "MPa",
        "MPa^-1": "MPa",
    }

    def run(
        self, creep_data: RheoData, method: str = "approximate"
    ) -> CreepToRelaxationPipeline:
        """Execute conversion workflow.

        Args:
            creep_data: RheoData with creep compliance J(t)
            method: Conversion method ('approximate', 'exact')

        Returns:
            self for method chaining

        Raises:
            ValueError: If ``method`` is not one of the supported names.

        Note:
            Input whose metadata declares a ``test_mode`` other than creep
            triggers a warning, not an error.
        """
        self.data = creep_data
        logger.info(
            "Starting creep to relaxation conversion",
            method=method,
            data_points=len(creep_data.x),
        )
        start_time = time.perf_counter()

        # R8-PIPE-004: guard against None metadata on programmatic RheoData
        metadata = getattr(creep_data, "metadata", None) or {}

        # Validate test mode. The key may be present but None or non-string,
        # so only string values are inspected.
        raw_mode = metadata.get("test_mode")
        test_mode = raw_mode.lower() if isinstance(raw_mode, str) else ""
        if test_mode and test_mode != "creep":
            warnings.warn(
                f"Input appears to be {test_mode} data, not creep. "
                "Results may be inaccurate.",
                stacklevel=2,
            )

        try:
            if method == "approximate":
                self._approximate_conversion()
            elif method == "exact":
                self._exact_conversion()
            else:
                raise ValueError(f"Unknown method: {method}")
            self.history.append(("creep_to_relaxation", method))
            total_time = time.perf_counter() - start_time
            logger.info(
                "Creep to relaxation conversion complete",
                method=method,
                total_time=total_time,
            )
        except Exception as e:
            logger.error(
                "Creep to relaxation conversion failed",
                method=method,
                error=str(e),
                exc_info=True,
            )
            raise
        return self

    def _approximate_conversion(self):
        """Apply approximate conversion G(t) ≈ 1/J(t).

        This is valid for small strains and elastic-dominant materials.

        Replaces ``self.data`` with a new RheoData whose metadata is tagged
        as relaxation data produced by the approximate method.
        """
        if self.data is None:
            return
        J_t = np.array(self.data.y)
        # Avoid division by zero
        J_t = np.maximum(J_t, 1e-20)
        G_t = 1.0 / J_t

        # The result is a reciprocal, so a known compliance unit label is
        # inverted; a missing label defaults to "Pa".
        source_units = self.data.y_units
        if not source_units:
            modulus_units = "Pa"
        else:
            modulus_units = self._INVERSE_UNITS.get(source_units, source_units)

        self.data = RheoData(
            x=self.data.x,
            y=G_t,
            x_units=self.data.x_units,
            y_units=modulus_units,
            domain=self.data.domain,
            metadata={
                **(self.data.metadata or {}),
                "test_mode": "relaxation",
                "conversion_method": "approximate",
            },
            validate=False,
        )

    def _exact_conversion(self):
        """Apply exact conversion using Laplace transform inversion.

        Not implemented yet: warns, runs the approximate conversion and
        labels the result ``"exact_approximate"`` in the metadata.
        """
        if self.data is None:
            return
        # This would use a proper Laplace transform inversion
        # For now, fall back to approximate
        warnings.warn(
            "Exact conversion not fully implemented. Using approximate method.",
            stacklevel=2,
        )
        self._approximate_conversion()
        self.data.metadata["conversion_method"] = "exact_approximate"
class FrequencyToTimePipeline(Pipeline):
    """Convert frequency domain data to time domain.

    This pipeline converts dynamic modulus G*(ω) to relaxation modulus G(t)
    by numerically integrating a cosine transform of the storage modulus.

    Example:
        >>> pipeline = FrequencyToTimePipeline()
        >>> pipeline.run(frequency_data)
        >>> time_data = pipeline.get_result()
    """

    def run(
        self,
        frequency_data: RheoData,
        time_range: tuple | None = None,
        n_points: int = 100,
    ) -> FrequencyToTimePipeline:
        """Execute frequency to time conversion.

        Args:
            frequency_data: RheoData in frequency domain
            time_range: Optional (t_min, t_max) for time range. When omitted
                the range is (1/ω_max, 1/ω_min) of the input frequencies.
            n_points: Number of log-spaced time points to generate

        Returns:
            self for method chaining

        Raises:
            ValueError: If the time range has to be derived from frequencies
                that are not all positive, or if an explicit ``time_range``
                contains a non-positive bound.
        """
        self.data = frequency_data
        logger.info(
            "Starting frequency to time conversion",
            n_points=n_points,
            input_points=len(frequency_data.x),
        )
        start_time = time.perf_counter()

        if frequency_data.domain != "frequency":
            warnings.warn("Input data may not be in frequency domain", stacklevel=2)

        try:
            omega = np.array(frequency_data.x)
            G_star = np.array(frequency_data.y)

            # Generate time points
            if time_range is None:
                # Auto-generate from frequency range
                w_min = np.min(omega)
                w_max = np.max(omega)
                if w_min <= 0 or w_max <= 0:
                    raise ValueError(
                        "Frequency data must be positive for time conversion"
                    )
                t_min = 1.0 / w_max
                t_max = 1.0 / w_min
            else:
                t_min, t_max = time_range
                # The grid is log-spaced: a non-positive bound would turn
                # into NaN/-inf inside log10 and poison every time point.
                if not (t_min > 0 and t_max > 0):
                    raise ValueError(
                        f"time_range bounds must be positive, got ({t_min}, {t_max})"
                    )
            t = np.logspace(np.log10(t_min), np.log10(t_max), n_points)

            # Placeholder: proper implementation would use FFT
            # For now, use simple numerical integration
            G_t = self._approximate_inverse_transform(t, omega, G_star)

            self.data = RheoData(
                x=t,
                y=G_t,
                x_units="s",
                y_units=frequency_data.y_units,
                domain="time",
                metadata={
                    **(frequency_data.metadata or {}),
                    "conversion": "frequency_to_time",
                    "original_domain": "frequency",
                },
                validate=False,
            )
            self.history.append(("frequency_to_time", str(n_points)))
            total_time = time.perf_counter() - start_time
            logger.info(
                "Frequency to time conversion complete",
                n_points=n_points,
                total_time=total_time,
            )
        except Exception as e:
            logger.error(
                "Frequency to time conversion failed",
                error=str(e),
                exc_info=True,
            )
            raise
        return self

    def _approximate_inverse_transform(
        self, t: np.ndarray, omega: np.ndarray, G_star: np.ndarray
    ) -> np.ndarray:
        """Inverse Fourier transform from G*(ω) to G(t).

        Uses trapezoidal integration over the sampled frequencies of:
            G(t) = (2/π) ∫ G'(ω) cos(ωt) dω

        Args:
            t: Time points
            omega: Angular frequency points (any order; sorted internally)
            G_star: Complex modulus, a two-column ``[G', G'']`` array, or
                just G'

        Returns:
            Relaxation modulus at time points, clipped to be non-negative
        """
        from scipy.integrate import trapezoid

        # NOTE(review): the textbook interconversion is usually written with
        # a 1/ω weight, e.g. G(t) = (2/π) ∫ [G'(ω)/ω] sin(ωt) dω. Confirm the
        # unweighted cosine form used here is intended before relying on
        # absolute values of the result.

        # Extract real part (storage modulus G')
        if np.iscomplexobj(G_star):
            G_prime = np.real(G_star)
        elif G_star.ndim == 2 and G_star.shape[1] == 2:
            G_prime = G_star[:, 0]
        else:
            G_prime = G_star

        # Sort by frequency for proper integration
        sort_idx = np.argsort(omega)
        omega_sorted = omega[sort_idx]
        G_prime_sorted = G_prime[sort_idx]

        # Float output buffer regardless of the dtype of ``t`` so integer
        # time points cannot truncate the integral.
        G_t = np.zeros(np.shape(t), dtype=float)
        for i, t_i in enumerate(t):
            integrand = G_prime_sorted * np.cos(omega_sorted * t_i)
            G_t[i] = (2.0 / np.pi) * trapezoid(integrand, omega_sorted)

        # Ensure non-negative (physical constraint)
        return np.maximum(G_t, 0.0)
class SPPAmplitudeSweepPipeline(Pipeline):
    """Pipeline for SPP analysis of amplitude sweep LAOS data.

    This pipeline performs SPP (Sequence of Physical Processes) analysis
    on amplitude sweep LAOS data to extract yield stress parameters and
    nonlinear viscoelastic metrics.

    Workflow:
        1. Load amplitude sweep data (multiple γ_0 values)
        2. Apply SPP decomposition at each amplitude
        3. Extract yield stresses (static and dynamic)
        4. Fit power-law scaling to yield stress vs amplitude
        5. Optionally fit Bayesian SPPYieldStress model

    Attributes:
        omega: Angular frequency of oscillation (rad/s)
        results: Dictionary of SPP metrics per amplitude (rebuilt by each
            call to :meth:`run`)
        model: Fitted SPPYieldStress model (after fit_model)

    Example:
        >>> pipeline = SPPAmplitudeSweepPipeline(omega=1.0)
        >>> pipeline.run(amplitude_data_list)
        >>> pipeline.fit_model(bayesian=True)
        >>> print(pipeline.get_yield_stresses())
    """

    def __init__(
        self,
        omega: float = 1.0,
        n_harmonics: int = 39,
        step_size: int = 8,
        num_mode: int = 2,
        wrap_strain_rate: bool = True,
        use_numerical_method: bool | None = None,
    ):
        """Initialize SPP amplitude sweep pipeline.

        Args:
            omega: Angular frequency in rad/s (default: 1.0)
            n_harmonics: Number of harmonics for SPP decomposition (default: 39)
            step_size: Differentiation step size k (default: 8, Rogers parity)
            num_mode: Numerical differentiation mode (default: 2 periodic)
            wrap_strain_rate: Whether to use wrapped differentiation when rate missing
            use_numerical_method: Force numerical path. ``None`` is forwarded
                to the decomposer as ``False``.
        """
        super().__init__()
        self.omega = omega
        self.n_harmonics = n_harmonics
        self.step_size = step_size
        self.num_mode = num_mode
        self.wrap_strain_rate = wrap_strain_rate
        self.use_numerical_method = use_numerical_method
        self.results: dict[float, dict] = {}  # gamma_0 -> SPP results
        self.model: SPPYieldStress | None = None
        self._gamma_0_values: list[float] = []
        self._sigma_sy_values: list[float] = []
        self._sigma_dy_values: list[float] = []
        # Yield type used by the most recent successful fit (for warm-start).
        self._last_fit_yield_type: str | None = None

    def run(
        self,
        stress_data: list[RheoData],
        gamma_0_values: list[float] | None = None,
    ) -> SPPAmplitudeSweepPipeline:
        """Execute SPP analysis on amplitude sweep data.

        Args:
            stress_data: List of RheoData objects, one per amplitude. Missing
                ``test_mode``, ``gamma_0`` and ``omega`` metadata entries are
                filled in on these objects in place.
            gamma_0_values: Strain amplitudes corresponding to each dataset.
                If None, extracted from RheoData metadata.

        Returns:
            self for method chaining

        Raises:
            ValueError: If gamma_0_values not provided and not in metadata,
                or if the number of datasets and amplitudes differ.

        Note:
            An amplitude whose decomposition fails is logged, reported via
            ``warnings.warn`` and skipped; the sweep continues.
        """
        from rheojax.transforms.spp_decomposer import SPPDecomposer

        # Extract gamma_0 values if not provided
        if gamma_0_values is None:
            gamma_0_values = []
            for data in stress_data:
                _g0_meta = data.metadata or {}
                if "gamma_0" not in _g0_meta:
                    raise ValueError(
                        "gamma_0_values must be provided or present in metadata"
                    )
                gamma_0_values.append(_g0_meta["gamma_0"])

        if len(stress_data) != len(gamma_0_values):
            raise ValueError(
                f"Number of datasets ({len(stress_data)}) must match "
                f"number of amplitudes ({len(gamma_0_values)})"
            )

        logger.info(
            "Starting SPP amplitude sweep analysis",
            n_datasets=len(stress_data),
            omega=self.omega,
        )
        start_time = time.perf_counter()

        # Reset per-run state once the inputs are validated; otherwise a
        # second run() would append duplicate amplitudes to the sweep lists.
        self.results = {}
        self._gamma_0_values = []
        self._sigma_sy_values = []
        self._sigma_dy_values = []
        n_successful = 0

        use_numerical = (
            False if self.use_numerical_method is None else self.use_numerical_method
        )

        # Process each amplitude
        for i, (gamma_0, data) in enumerate(
            zip(gamma_0_values, stress_data, strict=True)
        ):
            amplitude_start = time.perf_counter()

            # Ensure required metadata is present for downstream transforms/models
            if data.metadata is None:
                data.metadata = {}
            data.metadata.setdefault("test_mode", "oscillation")
            data.metadata.setdefault("gamma_0", gamma_0)
            data.metadata.setdefault("omega", self.omega)

            # Apply SPP decomposition
            decomposer = SPPDecomposer(
                omega=self.omega,
                gamma_0=gamma_0,
                n_harmonics=self.n_harmonics,
                step_size=self.step_size,
                num_mode=self.num_mode,
                wrap_strain_rate=self.wrap_strain_rate,
                use_numerical_method=use_numerical,
            )
            try:
                decomposer.transform(data)
                results = decomposer.get_results()
                self.results[float(gamma_0)] = results
                self._gamma_0_values.append(float(gamma_0))
                self._sigma_sy_values.append(results["sigma_sy"])
                self._sigma_dy_values.append(results["sigma_dy"])
                self.history.append(("spp_analyze", str(gamma_0), "success"))
                n_successful += 1
                logger.debug(
                    "SPP decomposition completed",
                    dataset=i,
                    gamma_0=gamma_0,
                    sigma_sy=results["sigma_sy"],
                    sigma_dy=results["sigma_dy"],
                    elapsed=time.perf_counter() - amplitude_start,
                )
            except Exception as e:
                logger.error(
                    "SPP analysis failed",
                    dataset=i,
                    gamma_0=gamma_0,
                    error=str(e),
                    exc_info=True,
                )
                warnings.warn(
                    f"SPP analysis failed at γ_0 = {gamma_0}: {e}", stacklevel=2
                )
                self.history.append(("spp_analyze", str(gamma_0), f"failed: {e}"))

        # Sort by amplitude
        order = np.argsort(self._gamma_0_values)
        self._gamma_0_values = [self._gamma_0_values[j] for j in order]
        self._sigma_sy_values = [self._sigma_sy_values[j] for j in order]
        self._sigma_dy_values = [self._sigma_dy_values[j] for j in order]

        total_time = time.perf_counter() - start_time
        logger.info(
            "SPP amplitude sweep analysis complete",
            n_datasets=len(stress_data),
            n_successful=n_successful,
            total_time=total_time,
        )
        return self

    def fit_model(
        self,
        bayesian: bool = False,
        yield_type: str = "static",
        **fit_kwargs,
    ) -> SPPAmplitudeSweepPipeline:
        """Fit SPPYieldStress model to extracted yield stresses.

        Args:
            bayesian: Whether to use Bayesian inference (default: False)
            yield_type: Which yield stress to fit ('static' or 'dynamic').
                Any value other than 'static' selects the dynamic values.
            **fit_kwargs: Additional arguments passed to fit or fit_bayesian

        Returns:
            self for method chaining

        Raises:
            RuntimeError: If :meth:`run` has not produced any yield stresses.
        """
        from rheojax.models.spp.spp_yield_stress import SPPYieldStress

        if not self._gamma_0_values:
            raise RuntimeError("No data available. Call run() first.")

        logger.info(
            "Starting SPP model fitting",
            bayesian=bayesian,
            yield_type=yield_type,
            n_points=len(self._gamma_0_values),
        )
        start_time = time.perf_counter()

        gamma_0_array = np.array(self._gamma_0_values)
        if yield_type == "static":
            sigma_array = np.array(self._sigma_sy_values)
        else:
            sigma_array = np.array(self._sigma_dy_values)

        # R6-PIPE-001: Only reuse warm-start if yield_type matches the prior fit.
        # getattr keeps instances built without __init__ working.
        prior_yield_type = getattr(self, "_last_fit_yield_type", None)
        can_warm_start = (
            bayesian
            and self.model is not None
            and self.model.fitted_
            and prior_yield_type == yield_type
        )
        if not can_warm_start:
            self.model = SPPYieldStress()

        try:
            if bayesian:
                self.model.fit_bayesian(
                    gamma_0_array,
                    sigma_array,
                    test_mode="oscillation",
                    **fit_kwargs,
                )
                self.history.append(("fit_bayesian", yield_type, "complete"))
            else:
                self.model.fit(
                    gamma_0_array,
                    sigma_array,
                    test_mode="oscillation",
                    yield_type=yield_type,
                    **fit_kwargs,
                )
                self.history.append(("fit_nlsq", yield_type, "complete"))
            self._last_fit_yield_type = yield_type
            total_time = time.perf_counter() - start_time
            logger.info(
                "SPP model fitting complete",
                bayesian=bayesian,
                yield_type=yield_type,
                total_time=total_time,
            )
        except Exception as e:
            logger.error(
                "SPP model fitting failed",
                bayesian=bayesian,
                yield_type=yield_type,
                error=str(e),
                exc_info=True,
            )
            raise
        return self

    def get_yield_stresses(self) -> dict[str, np.ndarray]:
        """Get extracted yield stresses from amplitude sweep.

        Returns:
            Dictionary with arrays sorted by amplitude:
                - gamma_0: strain amplitudes
                - sigma_sy: static yield stresses
                - sigma_dy: dynamic yield stresses
        """
        return {
            "gamma_0": np.array(self._gamma_0_values),
            "sigma_sy": np.array(self._sigma_sy_values),
            "sigma_dy": np.array(self._sigma_dy_values),
        }

    def get_amplitude_results(self, gamma_0: float) -> dict:
        """Get full SPP results for a specific amplitude.

        Args:
            gamma_0: Strain amplitude to retrieve. An exact match is tried
                first; otherwise a stored amplitude equal to it within a
                relative tolerance of 1e-9 is accepted, so a recomputed
                amplitude that differs only by rounding still resolves.

        Returns:
            Shallow copy of the SPP metrics for that amplitude

        Raises:
            KeyError: If amplitude not in results
        """
        if gamma_0 in self.results:
            return self.results[gamma_0].copy()
        try:
            target = float(gamma_0)
        except (TypeError, ValueError):
            target = None
        if target is not None:
            for stored, metrics in self.results.items():
                if np.isclose(stored, target, rtol=1e-9, atol=0.0):
                    return metrics.copy()
        raise KeyError(f"No results for γ_0 = {gamma_0}")

    def get_model(self) -> Any:
        """Get fitted SPPYieldStress model.

        Returns:
            Fitted model or None if not fitted
        """
        return self.model

    def get_nonlinearity_metrics(self) -> dict[float, dict]:
        """Get nonlinearity metrics (I3/I1, S, T) for each amplitude.

        Returns:
            Dictionary mapping gamma_0 to nonlinearity metrics; a metric
            missing from the stored results is reported as 0.0
        """
        metric_names = ("I3_I1_ratio", "S_factor", "T_factor")
        return {
            gamma_0: {name: results.get(name, 0.0) for name in metric_names}
            for gamma_0, results in self.results.items()
        }
# Public API of this module: the five workflow pipeline classes.
# _fit_model_in_subprocess is deliberately left out (private worker entry point).
__all__ = [
"MastercurvePipeline",
"ModelComparisonPipeline",
"CreepToRelaxationPipeline",
"FrequencyToTimePipeline",
"SPPAmplitudeSweepPipeline",
]