# Source code for gunz_ml.callbacks.early_stopping_optuna

"""
A PyTorch Lightning callback for Optuna pruning.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.1"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping

# =============================================================================
# MAIN CLASS
# =============================================================================
class EarlyStoppingOptuna(EarlyStopping):
    """
    A PyTorch Lightning callback that prunes an Optuna trial if the early
    stopping condition is met.

    This callback inherits from `lightning.pytorch.callbacks.EarlyStopping`
    and overrides the `_run_early_stopping_check` method. When the early
    stopping condition is fulfilled, it raises
    `optuna.exceptions.TrialPruned` to signal to Optuna that the trial
    should be pruned, instead of only asking the trainer to stop.

    In distributed runs the stop decision is reduced across all processes
    (stop if any process wants to stop), so that every process raises
    `TrialPruned` together instead of one process aborting while the others
    keep training.

    Parameters
    ----------
    monitor : str
        The quantity to be monitored (e.g., 'val_loss').
    min_delta : float, optional
        Minimum change in the monitored quantity to qualify as an
        improvement, by default 0.0.
    patience : int, optional
        Number of checks with no improvement after which training will be
        stopped, by default 3.
    verbose : bool, optional
        If True, prints a message for each pruning, by default False.
    mode : str, optional
        One of {'min', 'max'}. In 'min' mode, training will stop when the
        quantity monitored has stopped decreasing; in 'max' mode it will stop
        when the quantity monitored has stopped increasing, by default "min".
    strict : bool, optional
        Whether to crash the training if `monitor` is not found in the
        validation metrics, by default True.
    check_finite : bool, optional
        When set to `True`, stops training when the monitor becomes NaN or
        infinite, by default True.
    stopping_threshold : float, optional
        Stop training immediately once the monitored quantity reaches this
        threshold, by default None.
    divergence_threshold : float, optional
        Stop training as soon as the monitored quantity becomes worse than
        this threshold, by default None.
    check_on_train_epoch_end : bool, optional
        Whether to run early stopping at the end of the training epoch. If
        this is `False`, then the check runs at the end of the validation
        epoch, by default None.
    """

    def __init__(
        self,
        #? --- Monitoring Configuration ---
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        #? --- Thresholding Configuration ---
        stopping_threshold: t.Optional[float] = None,
        divergence_threshold: t.Optional[float] = None,
        #? --- Timing Configuration ---
        check_on_train_epoch_end: t.Optional[bool] = None,
    ):
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            strict=strict,
            check_finite=check_finite,
            stopping_threshold=stopping_threshold,
            divergence_threshold=divergence_threshold,
            check_on_train_epoch_end=check_on_train_epoch_end,
        )

    def _run_early_stopping_check(self, trainer: L.Trainer) -> None:
        """
        Checks whether the early stopping condition is met and if so, raises
        `optuna.exceptions.TrialPruned`.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.

        Raises
        ------
        optuna.exceptions.TrialPruned
            When the early stopping condition is met on any process.
        """
        logs = trainer.callback_metrics

        # Never prune during a `fast_dev_run` smoke test, and skip the check
        # when the monitored metric is unavailable.
        if trainer.fast_dev_run or not self._validate_condition_metric(logs):
            return

        current = logs[self.monitor].squeeze()
        local_stop, reason = self._evaluate_stopping_criteria(current)

        # Every process must reach the same decision; otherwise one rank
        # raises while the others continue and block in the next collective
        # operation. `all=False` means "stop if any process wants to stop".
        should_stop = trainer.strategy.reduce_boolean_decision(
            local_stop, all=False
        )
        if not should_stop:
            return

        if not local_stop:
            # This process did not trigger the stop itself, so its `reason`
            # is not a stopping reason (it may be None); replace it.
            reason = "Early stopping condition met on another process."

        self.stopped_epoch = trainer.current_epoch
        if self.verbose:
            print(f"Trial pruned: {reason}")
        raise optuna.exceptions.TrialPruned(reason)