# Source code for rheojax.transforms.mastercurve

"""Time-Temperature Superposition (TTS) mastercurve generation.

This module implements time-temperature superposition for creating mastercurves
from multi-temperature rheological data using WLF or Arrhenius shift factors,
or automatic shift factor calculation via power-law intersection (pyvisco algorithm).
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np

from rheojax.core.base import BaseTransform
from rheojax.core.data import RheoData
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger, log_transform
from rheojax.utils.optimization import create_least_squares_objective, nlsq_optimize

# Module logger
logger = get_logger(__name__)

# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()

# Typing shim: static type checkers see jax.numpy's array type, while at
# runtime the alias module falls back to NumPy so that no extra import of
# jax.numpy is needed just for annotations.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing
else:  # pragma: no cover - typing fallback
    jnp_typing = np

# Array type used in annotations throughout this module.
type JaxArray = jnp_typing.ndarray
# A temperature (or shift factor) given either as a scalar or as an array.
type ScalarOrArray = float | JaxArray


# Names accepted for the ``method`` argument of ``Mastercurve``.
ShiftMethod = Literal["wlf", "arrhenius", "manual"]


@TransformRegistry.register("mastercurve", type=TransformType.SUPERPOSITION)
class Mastercurve(BaseTransform):
    """Time-Temperature Superposition (TTS) mastercurve generation.

    This transform applies time-temperature superposition to create
    mastercurves from multi-temperature rheological data. Supports both WLF
    and Arrhenius shift factor models for horizontal shifting, with optional
    vertical shifting, or automatic shift factor calculation via power-law
    intersection method.

    The WLF equation is (base-10 logarithm):
        log10(a_T) = -C1 * (T - T_ref) / (C2 + (T - T_ref))

    The Arrhenius equation is (natural logarithm):
        ln(a_T) = (E_a / R) * (1/T - 1/T_ref)

    Automatic shift factors use power-law intersection (pyvisco algorithm):
        Fits each curve to y = a*x^b + e, then computes shift from intersection

    Parameters
    ----------
    reference_temp : float, default=298.15
        Reference temperature in Kelvin
    method : ShiftMethod, default='wlf'
        Shift factor method: 'wlf', 'arrhenius', or 'manual'
    C1 : float, default=17.44
        WLF parameter C1 (universal value for polymers)
    C2 : float, default=51.6
        WLF parameter C2 in Kelvin (universal value)
    E_a : float, optional
        Activation energy for Arrhenius (J/mol)
    vertical_shift : bool, default=False
        Whether to apply vertical shifting (for modulus scaling)
    optimize_shifts : bool, default=True
        Whether to optimize shift factors to minimize overlap error
    auto_shift : bool, default=False
        Whether to use automatic shift factor calculation via power-law
        intersection. If True, overrides manual WLF/Arrhenius calculations.

    Examples
    --------
    >>> from rheojax.core.data import RheoData
    >>> from rheojax.transforms.mastercurve import Mastercurve
    >>>
    >>> # Create multi-temperature frequency sweep data
    >>> # (In practice, this would come from experimental measurements)
    >>> temps = [273, 298, 323]  # K
    >>> freq = jnp.logspace(-2, 2, 50)
    >>> datasets = []
    >>> for T in temps:
    ...     G_prime = some_modulus_function(freq, T)
    ...     data = RheoData(x=freq, y=G_prime, domain='frequency',
    ...                     metadata={'temperature': T})
    ...     datasets.append(data)
    >>>
    >>> # Create mastercurve at reference temperature (two equivalent APIs)
    >>> mc = Mastercurve(reference_temp=298.15, method='wlf')
    >>>
    >>> # Option 1: Using create_mastercurve (explicit)
    >>> mastercurve = mc.create_mastercurve(datasets)
    >>>
    >>> # Option 2: Using transform with list (returns shift factors too)
    >>> mastercurve, shift_factors = mc.transform(datasets)
    >>> print(shift_factors)  # {273.0: 42.5, 298.15: 1.0, 323.0: 0.024}
    >>>
    >>> # Option 3: Automatic shift factor calculation
    >>> mc_auto = Mastercurve(reference_temp=298.15, auto_shift=True)
    >>> mastercurve_auto, shifts_auto = mc_auto.transform(datasets)
    """

    def __init__(
        self,
        reference_temp: float = 298.15,
        method: ShiftMethod = "wlf",
        C1: float = 17.44,
        C2: float = 51.6,
        E_a: float | None = None,
        vertical_shift: bool = False,
        optimize_shifts: bool = True,
        auto_shift: bool = False,
    ):
        """Initialize Mastercurve transform.

        Parameters
        ----------
        reference_temp : float
            Reference temperature in Kelvin
        method : ShiftMethod
            Shift factor method
        C1 : float
            WLF parameter C1
        C2 : float
            WLF parameter C2 (Kelvin)
        E_a : float, optional
            Activation energy for Arrhenius (J/mol)
        vertical_shift : bool
            Apply vertical shifting
        optimize_shifts : bool
            Optimize shift factors
        auto_shift : bool
            Use automatic power-law intersection for shift calculation
        """
        super().__init__()
        self.T_ref = reference_temp
        self.method = method
        self.C1 = C1
        self.C2 = C2
        self.E_a = E_a
        self.vertical_shift = vertical_shift
        self.optimize_shifts = optimize_shifts
        self._auto_shift = auto_shift

        # Store computed shift factors
        self.shift_factors_: dict[float, float] | None = None
        self.vertical_shifts_: dict[float, float] | None = None
        self._auto_shift_factors: np.ndarray | None = None

    def _calculate_wlf_shift(
        self, T: ScalarOrArray, T_ref: float, C1: float, C2: float
    ) -> ScalarOrArray:
        """Calculate WLF shift factor.

        Parameters
        ----------
        T : float or jnp.ndarray
            Temperature(s) in Kelvin
        T_ref : float
            Reference temperature in Kelvin
        C1 : float
            WLF parameter C1
        C2 : float
            WLF parameter C2 (Kelvin)

        Returns
        -------
        float or jnp.ndarray
            Shift factor a_T (1.0 at the WLF singularity T = T_ref - C2)
        """
        # WLF equation: log10(a_T) = -C1(T-T_ref)/(C2+(T-T_ref))
        denominator = C2 + (T - T_ref)
        # Guard against division by zero at WLF singularity (T = T_ref - C2)
        safe_denom = jnp.where(jnp.abs(denominator) < 1e-12, 1.0, denominator)
        log_aT = jnp.where(
            jnp.abs(denominator) < 1e-12,
            0.0,
            -C1 * (T - T_ref) / safe_denom,
        )
        return jnp.power(10.0, log_aT)

    def _calculate_arrhenius_shift(
        self, T: ScalarOrArray, T_ref: float, E_a: float
    ) -> ScalarOrArray:
        """Calculate Arrhenius shift factor.

        Parameters
        ----------
        T : float or jnp.ndarray
            Temperature(s) in Kelvin
        T_ref : float
            Reference temperature in Kelvin
        E_a : float
            Activation energy (J/mol)

        Returns
        -------
        float or jnp.ndarray
            Shift factor a_T
        """
        R = 8.314  # Gas constant (J/(mol·K))
        # Arrhenius: ln(a_T) = (E_a/R) * (1/T - 1/T_ref)
        log_aT = (E_a / R) * (1.0 / T - 1.0 / T_ref)
        return jnp.exp(log_aT)

    def _fit_power_law(
        self, x: np.ndarray, y: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Fit power-law model: y = a*x^b + e using NLSQ.

        Parameters
        ----------
        x : ndarray
            X data (frequency or time)
        y : ndarray
            Y data (modulus or other response)

        Returns
        -------
        popt : ndarray
            Optimal parameters [a, b, e]
        perr : ndarray
            Parameter uncertainties estimated from the Jacobian; filled with
            1e6 when no Jacobian is available or J^T J is singular
        """
        logger.debug(
            "Fitting power-law model",
            n_points=len(x),
            x_range=(float(x.min()), float(x.max())),
        )

        # Create parameter set with reasonable bounds
        params = ParameterSet()
        params.add("a", value=1.0, bounds=(1e-10, 1e10))
        params.add("b", value=-0.5, bounds=(-5.0, 5.0))
        params.add("e", value=0.0, bounds=(-1e10, 1e10))

        # Define model function for power-law
        def power_law_model(x_data: np.ndarray, param_values: np.ndarray) -> np.ndarray:
            """Power-law model: y = a*x^b + e."""
            a, b, e = param_values
            return a * jnp.power(x_data, b) + e

        # Create least-squares objective
        objective = create_least_squares_objective(power_law_model, x, y)

        # Optimize using NLSQ
        result = nlsq_optimize(
            objective, params, use_jax=True, max_iter=1000, ftol=1e-8, xtol=1e-8
        )

        # Extract parameter uncertainties from Jacobian
        if result.jac is not None:
            # Estimate covariance from Jacobian: Cov ≈ (J^T J)^-1
            try:
                jtj = result.jac.T @ result.jac
                cov = np.linalg.inv(jtj)
                perr = np.sqrt(np.maximum(np.diag(cov), 0.0))
            except np.linalg.LinAlgError:
                # Singular matrix, use large uncertainties
                logger.debug("Jacobian singular, using large uncertainties")
                perr = np.full(3, 1e6)
        else:
            perr = np.full(3, 1e6)

        logger.debug(
            "Power-law fit completed",
            a=float(result.x[0]),
            b=float(result.x[1]),
            e=float(result.x[2]),
        )

        return result.x, perr

    def _detect_outliers(
        self,
        x: np.ndarray,
        y: np.ndarray,
        popt_full: np.ndarray,
        perr_full: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Detect and remove outliers (first point) if it improves fit.

        Following pyvisco algorithm: try removing first point, keep removal
        if exponent error improves.

        Parameters
        ----------
        x : ndarray
            X data
        y : ndarray
            Y data
        popt_full : ndarray
            Parameters from full fit [a, b, e]
        perr_full : ndarray
            Uncertainties from full fit

        Returns
        -------
        x_clean : ndarray
            Cleaned x data
        y_clean : ndarray
            Cleaned y data
        popt_clean : ndarray
            Parameters from cleaned fit
        perr_clean : ndarray
            Uncertainties from cleaned fit
        """
        # The refit uses three parameters, so at least 3 points must remain
        # after dropping the first one (i.e. >= 4 points in total). With
        # fewer, the full-data result is returned unchanged.
        x_no_first = x[1:]
        y_no_first = y[1:]

        if len(x_no_first) < 3:
            # Too few points for 3-parameter power-law fit; keep full data
            logger.debug(
                "Too few points to attempt outlier removal (need >= 4 total)",
                n_points=len(x),
            )
            return x, y, popt_full, perr_full

        popt_no_first, perr_no_first = self._fit_power_law(x_no_first, y_no_first)

        # Compare exponent (b) uncertainty
        if perr_no_first[1] < perr_full[1]:
            # Removing first point improves fit — log truncation (T-006)
            logger.warning(
                "First data point removed during power-law fit "
                "(improved exponent uncertainty)",
                original_points=len(x),
                removed_x_value=float(x[0]),
                removed_y_value=float(y[0]),
            )
            return x_no_first, y_no_first, popt_no_first, perr_no_first

        # Keep all points
        return x, y, popt_full, perr_full

    @staticmethod
    def _floor_magnitude(value: float, floor: float) -> float:
        """Return ``value`` with its magnitude raised to at least ``floor``.

        The sign of ``value`` is preserved. An exact zero maps to ``+floor``
        (``-floor`` for negative zero); ``np.sign(value) * max(...)`` would
        return 0 there and defeat the purpose of the guard.
        """
        return np.copysign(max(abs(value), floor), value)

    def _compute_pairwise_shift(
        self,
        curve_top: np.ndarray,
        curve_bot: np.ndarray,
        popt_top: np.ndarray,
        popt_bot: np.ndarray,
    ) -> float:
        """Compute shift factor between two adjacent curves via intersection.

        Uses power-law intersection method: sample points in overlap/gap
        region, compute inverse power-law to find x-shift, average log(aT).

        Parameters
        ----------
        curve_top : ndarray
            Top curve data [x, y] with shape (N, 2)
        curve_bot : ndarray
            Bottom curve data [x, y] with shape (M, 2)
        popt_top : ndarray
            Power-law parameters [a, b, e] for top curve
        popt_bot : ndarray
            Power-law parameters [a, b, e] for bottom curve

        Returns
        -------
        log_aT : float
            Log10 shift factor
        """
        # Extract x and y ranges
        x_top, y_top = curve_top[:, 0], curve_top[:, 1]
        x_bot, y_bot = curve_bot[:, 0], curve_bot[:, 1]

        # Determine overlap or gap
        y_min_top, y_max_top = y_top.min(), y_top.max()
        y_min_bot, y_max_bot = y_bot.min(), y_bot.max()

        # Find y-range for intersection sampling
        if y_min_top < y_max_bot and y_max_bot < y_max_top:
            # Overlap case
            y_sample_min = max(y_min_top, y_min_bot)
            y_sample_max = min(y_max_top, y_max_bot)
        else:
            # Gap case: sample in gap region
            y_sample_min = min(y_max_bot, y_max_top)
            y_sample_max = max(y_min_bot, y_min_top)

        # Sample 10 points in y-range
        y_samples = np.linspace(y_sample_min, y_sample_max, 10)

        # Compute x from inverse power-law: x = ((y - e) / a)^(1/b)
        a_top, b_top, e_top = popt_top
        a_bot, b_bot, e_bot = popt_bot

        # Inverse power-law for top curve (guard against a≈0, negative base, b≈0)
        a_top_safe = self._floor_magnitude(a_top, 1e-20)
        base_top = np.maximum((y_samples - e_top) / a_top_safe, 1e-20)
        b_top_safe = self._floor_magnitude(b_top, 1e-10)
        x_top_inv = np.power(base_top, 1.0 / b_top_safe)

        # Inverse power-law for bottom curve (guard against a≈0, negative base, b≈0)
        a_bot_safe = self._floor_magnitude(a_bot, 1e-20)
        base_bot = np.maximum((y_samples - e_bot) / a_bot_safe, 1e-20)
        b_bot_safe = self._floor_magnitude(b_bot, 1e-10)
        x_bot_inv = np.power(base_bot, 1.0 / b_bot_safe)

        # Compute log shift factors and average
        log_shift_factors = np.log10(x_top_inv / x_bot_inv)

        # Handle infinities and NaNs
        valid = np.isfinite(log_shift_factors)
        if np.sum(valid) == 0:
            # Fallback: ratio of the (arithmetic) mean x-values of both curves
            return float(np.log10(np.mean(x_top) / np.mean(x_bot)))

        log_aT = float(np.mean(log_shift_factors[valid]))

        return log_aT

    def _compute_auto_shift_factors(
        self, datasets: list[RheoData], ref_temp_idx: int
    ) -> np.ndarray:
        """Compute automatic shift factors via sequential pairwise power-law intersection.

        Follows pyvisco algorithm:
        1. Fit power-law to each temperature curve
        2. Detect and remove outliers
        3. Compute pairwise shifts via intersection
        4. Accumulate shifts sequentially from reference temperature

        Parameters
        ----------
        datasets : list of RheoData
            Multi-temperature datasets (must be sorted by temperature)
        ref_temp_idx : int
            Index of reference temperature in datasets

        Returns
        -------
        log_aT_array : ndarray
            Cumulative log10 shift factors for all temperatures
        """
        n_temps = len(datasets)
        logger.debug(
            "Computing automatic shift factors",
            n_temperatures=n_temps,
            ref_temp_idx=ref_temp_idx,
        )
        log_aT_array = np.zeros(n_temps)

        # Fit power-law to each curve with outlier detection
        power_law_params = []
        curves = []

        for i, data in enumerate(datasets):
            x = np.asarray(data.x, dtype=np.float64)
            y = np.asarray(data.y, dtype=np.float64)

            logger.debug(
                "Fitting power-law for temperature curve",
                curve_index=i,
                n_points=len(x),
            )

            # Fit power-law
            popt, perr = self._fit_power_law(x, y)

            # Detect outliers
            x_clean, y_clean, popt_clean, perr_clean = self._detect_outliers(
                x, y, popt, perr
            )

            power_law_params.append(popt_clean)
            curves.append(np.column_stack([x_clean, y_clean]))

        # Sequential cumulative shifting below reference temperature
        for i in range(ref_temp_idx - 1, -1, -1):
            # Shift from i+1 to i (going down in temperature)
            log_shift = self._compute_pairwise_shift(
                curves[i + 1], curves[i], power_law_params[i + 1], power_law_params[i]
            )
            log_aT_array[i] = log_aT_array[i + 1] + log_shift
            logger.debug(
                "Pairwise shift computed",
                from_idx=i + 1,
                to_idx=i,
                log_shift=float(log_shift),
            )

        # Sequential cumulative shifting above reference temperature
        for i in range(ref_temp_idx + 1, n_temps):
            # Shift from i-1 to i (going up in temperature)
            log_shift = self._compute_pairwise_shift(
                curves[i - 1], curves[i], power_law_params[i - 1], power_law_params[i]
            )
            log_aT_array[i] = log_aT_array[i - 1] + log_shift
            logger.debug(
                "Pairwise shift computed",
                from_idx=i - 1,
                to_idx=i,
                log_shift=float(log_shift),
            )

        # Store for later retrieval
        self._auto_shift_factors = log_aT_array

        logger.debug(
            "Auto shift factors computed",
            shift_factors=log_aT_array.tolist(),
        )

        return log_aT_array

    def get_auto_shift_factors(self) -> tuple[np.ndarray, np.ndarray] | None:
        """Get automatic shift factors as arrays for plotting.

        Returns
        -------
        temperatures : ndarray or None
            Array of temperatures in Kelvin, or None if not computed
        log_aT : ndarray or None
            Array of log10 shift factors, or None if not computed
        """
        if self._auto_shift_factors is None:
            return None

        if self.shift_factors_ is None:
            return None

        temps = np.array(sorted(self.shift_factors_.keys()))
        return temps, self._auto_shift_factors

    def get_shift_factor(self, T: float) -> float:
        """Get shift factor for a given temperature.

        Parameters
        ----------
        T : float
            Temperature in Kelvin

        Returns
        -------
        float
            Horizontal shift factor a_T. In 'manual' mode a temperature with
            no stored factor yields 1.0.

        Raises
        ------
        ValueError
            If the Arrhenius method is selected without ``E_a``, if manual
            shift factors were never set, or if the method is unknown
        """
        if self.method == "wlf":
            return float(self._calculate_wlf_shift(T, self.T_ref, self.C1, self.C2))
        elif self.method == "arrhenius":
            if self.E_a is None:
                raise ValueError("E_a must be provided for Arrhenius method")
            return float(self._calculate_arrhenius_shift(T, self.T_ref, self.E_a))
        elif self.method == "manual":
            if self.shift_factors_ is None:
                raise ValueError("Manual shift factors not set")
            return self.shift_factors_.get(T, 1.0)
        else:
            raise ValueError(f"Unknown shift method: {self.method}")

    def set_manual_shifts(self, shift_factors: dict[float, float]):
        """Set manual shift factors for each temperature.

        Also switches ``method`` to 'manual'.

        Parameters
        ----------
        shift_factors : dict
            Dictionary mapping temperature (K) to shift factor
        """
        self.method = "manual"
        self.shift_factors_ = shift_factors

    def get_wlf_parameters(self) -> dict[str, float]:
        """Get WLF parameters.

        Returns
        -------
        dict
            Dictionary with keys 'C1', 'C2', and 'T_ref' (reference temperature)

        Raises
        ------
        ValueError
            If method is not 'wlf'
        """
        if self.method != "wlf":
            raise ValueError(f"WLF parameters not available for method '{self.method}'")

        return {
            "C1": self.C1,
            "C2": self.C2,
            "T_ref": self.T_ref,
        }

    def get_arrhenius_parameters(self) -> dict[str, float]:
        """Get Arrhenius parameters.

        Returns
        -------
        dict
            Dictionary with keys 'E_a' (activation energy) and 'T_ref'
            (reference temperature)

        Raises
        ------
        ValueError
            If method is not 'arrhenius' or E_a is not set
        """
        if self.method != "arrhenius":
            raise ValueError(
                f"Arrhenius parameters not available for method '{self.method}'"
            )

        if self.E_a is None:
            raise ValueError("E_a (activation energy) not set")

        return {
            "E_a": self.E_a,
            "T_ref": self.T_ref,
        }

    def get_shift_factors_array(
        self, temperatures: list[float] | np.ndarray | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get shift factors as arrays for plotting and analysis.

        Parameters
        ----------
        temperatures : list or ndarray, optional
            Temperatures in Kelvin. If None, uses temperatures from the last
            mastercurve creation (stored in ``shift_factors_``).

        Returns
        -------
        temperatures : ndarray
            Array of temperatures in Kelvin (sorted)
        shift_factors : ndarray
            Array of shift factors corresponding to temperatures

        Raises
        ------
        ValueError
            If temperatures is None and no shift factors have been computed

        Examples
        --------
        >>> mc = Mastercurve(reference_temp=298.15, method='wlf')
        >>> temps, shifts = mc.get_shift_factors_array([273.15, 298.15, 323.15])
        >>> import matplotlib.pyplot as plt
        >>> plt.plot(temps - 273.15, np.log10(shifts))
        """
        if temperatures is None:
            # Use stored shift factors from last mastercurve creation
            if self.shift_factors_ is None:
                raise ValueError(
                    "No shift factors available. Either provide temperatures or "
                    "create a mastercurve first."
                )

            # Extract from stored shift factors
            temps_array = np.array(sorted(self.shift_factors_.keys()))
            shifts_array = np.array([self.shift_factors_[T] for T in temps_array])
        else:
            # Calculate shift factors for provided temperatures
            temps_array = np.array(temperatures)

            # Sort by temperature
            sort_idx = np.argsort(temps_array)
            temps_array = temps_array[sort_idx]

            # Calculate shift factors
            shifts_array = np.array(
                [self.get_shift_factor(float(T)) for T in temps_array]
            )

        return temps_array, shifts_array

    def _transform_single(self, data: RheoData) -> RheoData:
        """Apply horizontal shift to single-temperature data.

        Parameters
        ----------
        data : RheoData
            Single-temperature data to shift

        Returns
        -------
        RheoData
            Shifted data

        Raises
        ------
        ValueError
            If temperature metadata is missing
        """
        # Get temperature from metadata
        _meta_ts = data.metadata or {}
        if "temperature" not in _meta_ts:
            raise ValueError("Temperature must be in metadata for mastercurve shifting")

        T = _meta_ts["temperature"]

        # Get shift factor
        a_T = self.get_shift_factor(T)

        # Apply horizontal shift (frequency or time shift)
        x_shifted = data.x * a_T  # type: ignore[operator]

        # Apply vertical shift if requested
        y_shifted = data.y
        if self.vertical_shift:
            # For temperature-dependent modulus: G(T) ~ rho(T) * T
            # Vertical shift factor: b_T = rho(T) * T / (rho(T_ref) * T_ref)
            # Simplified: b_T = T / T_ref
            b_T = T / self.T_ref
            y_shifted = y_shifted * b_T

        # Create metadata
        new_metadata = _meta_ts.copy()
        new_metadata.update(
            {
                "transform": "mastercurve",
                "reference_temperature": self.T_ref,
                "shift_method": self.method,
                "horizontal_shift": float(a_T),
                "vertical_shift": float(T / self.T_ref) if self.vertical_shift else 1.0,
            }
        )

        return RheoData(
            x=x_shifted,  # type: ignore[arg-type]
            y=y_shifted,
            x_units=data.x_units,
            y_units=data.y_units,
            domain=data.domain,
            metadata=new_metadata,
            validate=False,
        )

    def _transform(
        self, data: RheoData | list[RheoData]
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Apply horizontal shift to single-temperature data or create mastercurve.

        Parameters
        ----------
        data : RheoData or list of RheoData
            Single-temperature data to shift, or list of datasets for mastercurve

        Returns
        -------
        RheoData or tuple of (RheoData, dict)
            If data is a single RheoData: returns shifted data
            If data is a list: returns (mastercurve, ``shift_factors``)

        Raises
        ------
        ValueError
            If temperature metadata is missing
        """
        # Determine input shape for logging
        if isinstance(data, list):
            input_shape = (len(data),)
        else:
            input_shape = (len(data.x),) if hasattr(data.x, "__len__") else (1,)  # type: ignore[arg-type]

        with log_transform(
            logger,
            "mastercurve",
            input_shape=input_shape,
            method=self.method,
            reference_temp=self.T_ref,
            auto_shift=self._auto_shift,
        ) as ctx:
            # Handle list of datasets (create mastercurve)
            if isinstance(data, list):
                result = self.create_mastercurve(data, return_shifts=True)
                if isinstance(result, tuple):
                    mastercurve, shift_factors = result
                    ctx["output_shape"] = (len(mastercurve.x),)  # type: ignore[arg-type]
                    ctx["n_temperatures"] = len(shift_factors)
                return result

            # Handle single dataset
            result = self._transform_single(data)
            ctx["output_shape"] = (len(result.x),)  # type: ignore[arg-type]
            return result

    def create_mastercurve(
        self, datasets: list[RheoData], merge: bool = True, return_shifts: bool = False
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Create mastercurve from multiple temperature datasets.

        Parameters
        ----------
        datasets : list of RheoData
            List of datasets at different temperatures
        merge : bool, default=True
            If True, merge all shifted data into single RheoData.
            If False, return list of shifted datasets.
        return_shifts : bool, default=False
            If True, return tuple of (mastercurve, shift_factors).
            Only valid when merge=True.

        Returns
        -------
        RheoData or list of RheoData or tuple
            If merge=True and return_shifts=False: RheoData
            If merge=False: list of RheoData
            If merge=True and return_shifts=True: (RheoData, dict of shift factors)

        Raises
        ------
        ValueError
            If ``datasets`` is empty, if datasets don't have temperature
            metadata, or if return_shifts=True with merge=False
        """
        logger.debug(
            "Creating mastercurve",
            n_datasets=len(datasets),
            merge=merge,
            return_shifts=return_shifts,
        )

        if not datasets:
            raise ValueError(
                "create_mastercurve requires at least one dataset. "
                "Received an empty list."
            )

        if return_shifts and not merge:
            logger.error(
                "Invalid configuration: return_shifts=True requires merge=True"
            )
            raise ValueError("return_shifts=True requires merge=True")

        # Extract temperatures and sort datasets
        temperatures = []
        for data in datasets:
            _dmeta = data.metadata or {}
            if "temperature" not in _dmeta:
                raise ValueError("All datasets must have 'temperature' in metadata")
            temperatures.append(_dmeta["temperature"])

        # Sort by temperature
        temp_indices = np.argsort(temperatures)
        datasets = [datasets[i] for i in temp_indices]
        temperatures = [temperatures[i] for i in temp_indices]

        # Find reference temperature index
        ref_temp_idx = np.argmin(np.abs(np.array(temperatures) - self.T_ref))

        # Compute one shift factor per dataset, aligned with ``datasets``.
        # Keeping them in a list (rather than reading them back from the
        # temperature-keyed dict) ensures that two datasets recorded at the
        # same temperature each keep their own automatically computed shift.
        if self._auto_shift:
            # Use automatic shift factor calculation
            log_aT_array = self._compute_auto_shift_factors(datasets, int(ref_temp_idx))
            shift_values = [10.0**log_aT for log_aT in log_aT_array]
        else:
            # Use manual WLF/Arrhenius/manual method
            shift_values = [float(self.get_shift_factor(T)) for T in temperatures]

        shift_factors = dict(zip(temperatures, shift_values, strict=True))

        if self._auto_shift:
            # _compute_auto_shift_factors has just replaced
            # ``_auto_shift_factors``; store the matching temperature map now
            # so get_auto_shift_factors() stays consistent even when
            # merge=False returns early below.
            self.shift_factors_ = shift_factors

        # Shift all datasets
        shifted_datasets = []
        for data, T, a_T in zip(datasets, temperatures, shift_values, strict=True):
            # Apply horizontal shift
            x_shifted = data.x * a_T

            # Apply vertical shift if requested
            y_shifted = data.y
            if self.vertical_shift:
                b_T = T / self.T_ref
                y_shifted = y_shifted * b_T

            # Create metadata — use THIS dataset's metadata (not the loop variable
            # _dmeta which leaks from the temperature-extraction loop above and
            # would give every shifted dataset the metadata of the last dataset).
            new_metadata = (data.metadata or {}).copy()
            new_metadata.update(
                {
                    "transform": "mastercurve",
                    "reference_temperature": self.T_ref,
                    "shift_method": "auto" if self._auto_shift else self.method,
                    "horizontal_shift": float(a_T),
                    "vertical_shift": (
                        float(T / self.T_ref) if self.vertical_shift else 1.0
                    ),
                }
            )

            shifted = RheoData(
                x=x_shifted,
                y=y_shifted,
                x_units=data.x_units,
                y_units=data.y_units,
                domain=data.domain,
                metadata=new_metadata,
                validate=False,
            )
            shifted_datasets.append(shifted)

        # If not merging, return list
        if not merge:
            return shifted_datasets

        # Merge all shifted data
        all_x = []
        all_y = []
        all_temps = []

        for data, T in zip(shifted_datasets, temperatures, strict=True):
            x_data = data.x if isinstance(data.x, np.ndarray) else np.array(data.x)
            y_data = data.y if isinstance(data.y, np.ndarray) else np.array(data.y)

            all_x.append(x_data)
            all_y.append(y_data)
            all_temps.extend([T] * len(x_data))

        # Concatenate
        merged_x = np.concatenate(all_x)
        merged_y = np.concatenate(all_y)
        merged_temps = np.array(all_temps)

        # Sort by x-axis
        sort_idx = np.argsort(merged_x)
        merged_x = merged_x[sort_idx]
        merged_y = merged_y[sort_idx]
        merged_temps = merged_temps[sort_idx]

        # Create merged metadata
        merged_metadata = {
            "transform": "mastercurve",
            "reference_temperature": self.T_ref,
            "shift_method": "auto" if self._auto_shift else self.method,
            "temperatures": temperatures,
            "n_datasets": len(datasets),
            "source_temperatures": merged_temps,
            "shift_factors": shift_factors,
        }

        # ``datasets`` is guaranteed non-empty by the check at the top.
        mastercurve = RheoData(
            x=merged_x,
            y=merged_y,
            x_units=datasets[0].x_units,
            y_units=datasets[0].y_units,
            domain=datasets[0].domain,
            metadata=merged_metadata,
            validate=False,
        )

        # Store shift factors for later retrieval
        self.shift_factors_ = shift_factors

        if return_shifts:
            return mastercurve, shift_factors
        return mastercurve

    def compute_overlap_error(self, datasets: list[RheoData]) -> float:
        """Compute overlap error for multi-temperature data.

        This metric quantifies how well the datasets collapse onto a
        mastercurve. Lower values indicate better superposition.

        Every pair of shifted curves whose x-ranges overlap is interpolated
        onto 50 linearly spaced points across the overlap and compared by
        RMSE. Complex data contributes the RMSE of real and imaginary parts
        combined.

        Parameters
        ----------
        datasets : list of RheoData
            List of datasets at different temperatures

        Returns
        -------
        float
            Mean RMSE over all overlapping curve pairs (same units as ``y``;
            not normalized), or ``inf`` when no two curves overlap
        """
        # Create shifted datasets
        shifted_datasets = self.create_mastercurve(datasets, merge=False)

        if not isinstance(shifted_datasets, list):
            shifted_datasets = [shifted_datasets]  # type: ignore[list-item]

        # Convert and sort every curve once instead of once per pair.
        curves = []
        for shifted in shifted_datasets:
            x_vals = (
                shifted.x if isinstance(shifted.x, np.ndarray) else np.array(shifted.x)
            )
            y_vals = np.asarray(shifted.y)
            order = np.argsort(x_vals)
            curves.append((x_vals[order], y_vals[order]))

        # Find overlapping regions and compute RMSE
        total_error = 0.0
        n_overlaps = 0

        for i, (x_i, y_i) in enumerate(curves):
            for x_j, y_j in curves[i + 1 :]:
                # Find overlap region
                x_min = max(x_i.min(), x_j.min())
                x_max = min(x_i.max(), x_j.max())

                if x_max <= x_min:
                    continue  # No overlap

                # Interpolate both datasets to common x-axis in overlap region
                x_common = np.linspace(x_min, x_max, 50)

                # T-22: Handle complex G* data — compute overlap error
                # on both G' and G'' independently, then combine RMSE.
                if np.iscomplexobj(y_i) or np.iscomplexobj(y_j):
                    y_i_c = np.asarray(y_i, dtype=np.complex128)
                    y_j_c = np.asarray(y_j, dtype=np.complex128)
                    # Interpolate real and imag parts separately
                    y_i_real = np.interp(x_common, x_i, np.real(y_i_c))
                    y_i_imag = np.interp(x_common, x_i, np.imag(y_i_c))
                    y_j_real = np.interp(x_common, x_j, np.real(y_j_c))
                    y_j_imag = np.interp(x_common, x_j, np.imag(y_j_c))
                    error_real = np.mean((y_i_real - y_j_real) ** 2)
                    error_imag = np.mean((y_i_imag - y_j_imag) ** 2)
                    error = np.sqrt((error_real + error_imag) / 2.0)
                else:
                    y_i_interp = np.interp(x_common, x_i, y_i)
                    y_j_interp = np.interp(x_common, x_j, y_j)
                    error = np.sqrt(np.mean((y_i_interp - y_j_interp) ** 2))

                total_error += error
                n_overlaps += 1

        if n_overlaps == 0:
            return float("inf")

        return total_error / n_overlaps

    def optimize_wlf_parameters(
        self,
        datasets: list[RheoData],
        initial_C1: float = 17.44,
        initial_C2: float = 51.6,
    ) -> tuple[float, float]:
        """Optimize WLF parameters to minimize overlap error.

        The overlap error is always evaluated with WLF shifting during the
        search, regardless of the configured ``method`` / ``auto_shift``;
        both settings are restored before returning. On success ``C1`` and
        ``C2`` are updated to the optimum; if the optimizer raises, their
        previous values are restored.

        Parameters
        ----------
        datasets : list of RheoData
            Multi-temperature datasets
        initial_C1 : float
            Initial guess for C1
        initial_C2 : float
            Initial guess for C2

        Returns
        -------
        C1_opt : float
            Optimized C1 parameter
        C2_opt : float
            Optimized C2 parameter

        Note
        ----
        Uses scipy.optimize.minimize (Nelder-Mead) because the objective
        function compute_overlap_error() uses NumPy interpolation which is
        not JAX-traceable. This is acceptable per Technical Guidelines as
        it's not in a hot path and is called only once during WLF parameter
        fitting.
        """
        from scipy.optimize import minimize

        prev_method, prev_auto = self.method, self._auto_shift
        prev_C1, prev_C2 = self.C1, self.C2

        # Without this, a non-WLF method or auto_shift=True makes the
        # objective independent of (C1, C2) and the search meaningless.
        self.method = "wlf"
        self._auto_shift = False

        def objective(params):
            """Objective function: overlap error."""
            C1, C2 = params
            self.C1 = C1
            self.C2 = C2
            return self.compute_overlap_error(datasets)

        try:
            # Optimize using Nelder-Mead (derivative-free, appropriate for
            # non-JAX objective)
            result = minimize(
                objective,
                x0=[initial_C1, initial_C2],
                method="Nelder-Mead",
            )
        except Exception:
            # Do not leave an arbitrary trial point behind on failure.
            self.C1, self.C2 = prev_C1, prev_C2
            raise
        finally:
            self.method, self._auto_shift = prev_method, prev_auto

        C1_opt, C2_opt = (float(value) for value in result.x)
        self.C1 = C1_opt
        self.C2 = C2_opt

        return C1_opt, C2_opt
# Public API of this module: only the Mastercurve transform is exported.
__all__ = ["Mastercurve"]