"""
A PyTorch Lightning callback for Optuna pruning.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.1"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
# =============================================================================
# MAIN CLASS
# =============================================================================
class EarlyStoppingOptuna(EarlyStopping):
    """
    A PyTorch Lightning callback that prunes an Optuna trial if the early
    stopping condition is met.

    This callback inherits from `lightning.pytorch.callbacks.EarlyStopping`
    and overrides the `_run_early_stopping_check` method. When the early
    stopping condition is fulfilled, it raises `optuna.exceptions.TrialPruned`
    to signal to Optuna that the trial should be pruned.

    In multi-process runs the stop decision is reduced across all processes
    (any process deciding to stop makes every process stop). This ensures all
    ranks raise together instead of some of them blocking in a later
    collective operation. As in the parent class, the check is skipped when
    the trainer runs with ``fast_dev_run``.

    Parameters
    ----------
    monitor : str
        The quantity to be monitored (e.g., 'val_loss').
    min_delta : float, optional
        Minimum change in the monitored quantity to qualify as an improvement,
        by default 0.0.
    patience : int, optional
        Number of checks with no improvement after which training will be
        stopped, by default 3.
    verbose : bool, optional
        If True, prints a message (from the global-zero process only) when the
        trial is pruned, by default False.
    mode : str, optional
        One of {'min', 'max'}. In 'min' mode, training will stop when the
        quantity monitored has stopped decreasing; in 'max' mode it will stop
        when the quantity monitored has stopped increasing, by default "min".
    strict : bool, optional
        Whether to crash the training if `monitor` is not found in the
        validation metrics, by default True.
    check_finite : bool, optional
        When set to `True`, stops training when the monitor becomes NaN or
        infinite, by default True.
    stopping_threshold : float, optional
        Stop training immediately once the monitored quantity reaches this
        threshold, by default None.
    divergence_threshold : float, optional
        Stop training as soon as the monitored quantity becomes worse than
        this threshold, by default None.
    check_on_train_epoch_end : bool, optional
        Whether to run early stopping at the end of the training epoch. If
        this is `False`, then the check runs at the end of the validation
        epoch, by default None.
    **kwargs
        Any further keyword arguments are forwarded unchanged to
        `EarlyStopping.__init__`. This exposes options of the installed
        Lightning version that are not listed above.
    """

    def __init__(
        self,
        #? --- Monitoring Configuration ---
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        #? --- Thresholding Configuration ---
        stopping_threshold: t.Optional[float] = None,
        divergence_threshold: t.Optional[float] = None,
        #? --- Timing Configuration ---
        check_on_train_epoch_end: t.Optional[bool] = None,
        #? --- Pass-through for other EarlyStopping options ---
        **kwargs: t.Any,
    ):
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            strict=strict,
            check_finite=check_finite,
            stopping_threshold=stopping_threshold,
            divergence_threshold=divergence_threshold,
            check_on_train_epoch_end=check_on_train_epoch_end,
            **kwargs,
        )

    def _run_early_stopping_check(self, trainer: L.Trainer) -> None:
        """
        Checks whether the early stopping condition is met and if so,
        raises `optuna.exceptions.TrialPruned`.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.

        Raises
        ------
        optuna.exceptions.TrialPruned
            When the (process-reduced) early stopping condition is met.
        """
        # Match the parent class: early stopping is disabled for fast_dev_run,
        # which is a smoke test and must not prune the trial.
        if trainer.fast_dev_run:
            return

        logs = trainer.callback_metrics
        # Short-circuit when the monitored metric cannot be validated; the
        # parent's helper decides between skipping and raising (``strict``).
        if not self._validate_condition_metric(logs):
            return

        current = logs[self.monitor].squeeze()
        local_stop, reason = self._evaluate_stopping_criteria(current)

        # Every process must reach the same decision. Otherwise only some
        # ranks would raise, and the rest would wait forever in the next
        # collective. ``all=False`` means "stop if any process wants to stop".
        should_stop = trainer.strategy.reduce_boolean_decision(
            local_stop, all=False
        )
        if not should_stop:
            return

        self.stopped_epoch = trainer.current_epoch

        # On a process that did not trigger the stop itself, ``reason`` may be
        # None or an unrelated message, so use a generic explanation instead.
        if local_stop and reason:
            message = reason
        else:
            message = "Early stopping condition met on another process."

        if self.verbose and trainer.is_global_zero:
            print(f"Trial pruned: {message}")
        raise optuna.exceptions.TrialPruned(message)