# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Early Stopping
^^^^^^^^^^^^^^
Monitor a metric and stop training when it stops improving.
"""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
import torch
from torch import Tensor
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
# Module-level logger; ``EarlyStopping._log_info`` emits its messages through it.
log = logging.getLogger(__name__)
class EarlyStopping(Callback):
    r"""
    Monitor a metric and stop training when it stops improving.

    Parameters
    ----------
    monitor : str
        Quantity to be monitored.
    min_delta : float
        Minimum change in the monitored quantity to qualify as an improvement,
        i.e. an absolute change of less than or equal to ``min_delta``
        will count as no improvement.
    patience : int
        Number of checks with no improvement after which training will be stopped.
        Under the default configuration, one check happens after every training epoch.
        However, the frequency of validation can be modified by setting various
        parameters on the ``Trainer``, for example ``check_val_every_n_epoch``
        and ``val_check_interval``.
    verbose : bool
        Verbosity mode. When truthy, improvement/stopping messages are logged and a
        missing monitor (with ``strict=False``) produces a warning.
    mode : str
        One of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the
        quantity monitored has stopped decreasing and in ``'max'`` mode it will stop
        when the quantity monitored has stopped increasing.
    strict : bool
        Whether to crash the training if ``monitor`` is not found in the validation metrics.
    check_finite : bool
        When set ``True``, stops training when the monitor becomes NaN or infinite.
    stopping_threshold : float, optional
        Stop training immediately once the monitored quantity is strictly better
        than this threshold (``<`` in ``'min'`` mode, ``>`` in ``'max'`` mode).
    divergence_threshold : float, optional
        Stop training as soon as the monitored quantity becomes strictly worse than this threshold.
    check_on_train_epoch_end : bool, optional
        Whether to run early stopping at the end of the training epoch. If this
        is ``False``, then the check runs at the end of the validation. If ``None``,
        the choice is made in :meth:`setup` from the trainer's validation schedule.
    log_rank_zero_only : bool
        When set ``True``, logs the status of the early stopping callback only for rank 0 process.

    Raises
    ------
    MisconfigurationException
        If ``mode`` is none of ``"min"`` or ``"max"``.
    RuntimeError
        If the metric ``monitor`` is not available and ``strict=True``.

    Examples
    --------
    >>> from lightning.pytorch import Trainer
    >>> from lightning.pytorch.callbacks import EarlyStopping
    >>> early_stopping = EarlyStopping('val_loss')
    >>> trainer = Trainer(callbacks=[early_stopping])

    Notes
    -----
    The ``patience`` parameter counts the number of validation checks with
    no improvement, and not the number of training epochs. Therefore, with
    parameters ``check_val_every_n_epoch=10`` and ``patience=3``, the
    trainer will perform at least 40 training epochs before being stopped.

    .. tip::
        Saving and restoring multiple early stopping callbacks at the same time
        is supported under variation in the following arguments:

        *monitor, mode*
    """

    # "Is `a` better than `b`?" comparison for each mode.
    mode_dict = {"min": torch.lt, "max": torch.gt}
    # Textual form of the "better than" comparison, used in log messages.
    order_dict = {"min": "<", "max": ">"}

    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False,
    ):
        super().__init__()
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.strict = strict
        self.check_finite = check_finite
        self.stopping_threshold = stopping_threshold
        self.divergence_threshold = divergence_threshold
        self.wait_count = 0
        self.stopped_epoch = 0
        self._check_on_train_epoch_end = check_on_train_epoch_end
        self.log_rank_zero_only = log_rank_zero_only

        # Validate before anything touches ``monitor_op`` (which indexes ``mode_dict``).
        if self.mode not in self.mode_dict:
            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")

        # Store min_delta with a sign so that ``current - min_delta`` shifts the value
        # against the improving direction in both modes.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        # ``float("inf")`` rather than ``np.Inf``: the latter alias was removed in NumPy 2.0.
        torch_inf = torch.tensor(float("inf"))
        # Start from the worst possible score so the first finite value is an improvement.
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    @property
    def state_key(self) -> str:
        """
        Generate a unique identifier for the callback state.

        Returns
        -------
        str
            The state key based on the monitored metric and mode.
        """
        return self._generate_state_key(monitor=self.monitor, mode=self.mode)

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """
        Called when fit, validate, test, or predict begins.

        Resolves ``check_on_train_epoch_end`` when the user left it as ``None``.

        Parameters
        ----------
        trainer : pl.Trainer
            The Trainer instance.
        pl_module : pl.LightningModule
            The LightningModule instance.
        stage : str
            The stage (e.g., 'fit', 'validate', 'test', 'predict').
        """
        if self._check_on_train_epoch_end is None:
            # if the user runs validation multiple times per training epoch or multiple training epochs without
            # validation, then we run after validation instead of on train epoch end
            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

    def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
        """
        Check that the monitored metric is present in ``logs``.

        Returns ``True`` when it is. When it is missing, raises ``RuntimeError`` if
        ``strict`` is set, otherwise (optionally warning when ``verbose``) returns ``False``.
        """
        monitor_val = logs.get(self.monitor)

        error_msg = (
            f"Early stopping conditioned on metric `{self.monitor}` which is not available."
            " Pass in or modify your `EarlyStopping` callback to use any of the following:"
            f' `{"`, `".join(list(logs.keys()))}`'
        )

        if monitor_val is None:
            if self.strict:
                raise RuntimeError(error_msg)
            if self.verbose > 0:
                rank_zero_warn(error_msg, category=RuntimeWarning)

            return False

        return True

    @property
    def monitor_op(self) -> Callable:
        """
        Get the comparison operator based on the mode.

        Returns
        -------
        Callable
            Comparison function (e.g., torch.lt or torch.gt).
        """
        return self.mode_dict[self.mode]

    def state_dict(self) -> Dict[str, Any]:
        """
        Get the callback state dictionary.

        Returns
        -------
        Dict[str, Any]
            The state dictionary containing wait count, stopped epoch, best score, and patience.
        """
        return {
            "wait_count": self.wait_count,
            "stopped_epoch": self.stopped_epoch,
            "best_score": self.best_score,
            "patience": self.patience,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the callback state from a dictionary.

        Parameters
        ----------
        state_dict : Dict[str, Any]
            The state dictionary to restore from (as produced by :meth:`state_dict`).
        """
        self.wait_count = state_dict["wait_count"]
        self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]

    def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
        """Return ``True`` when the trainer is not fitting or is running the sanity check."""
        # Import from the same package this module lives in (``lightning.pytorch``);
        # ``pytorch_lightning`` may not be installed alongside it.
        from lightning.pytorch.trainer.states import TrainerFn

        return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """
        Called when the train epoch ends.

        Runs the early-stopping check only when it is configured to run at train-epoch end.

        Parameters
        ----------
        trainer : pl.Trainer
            The Trainer instance.
        pl_module : pl.LightningModule
            The LightningModule instance.
        """
        # When the check is configured to run after validation, ``on_validation_end``
        # owns it; running here too would count each epoch twice against ``patience``.
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """
        Called when the validation loop ends.

        Runs the early-stopping check only when it is NOT configured to run at train-epoch end.

        Parameters
        ----------
        trainer : pl.Trainer
            The Trainer instance.
        pl_module : pl.LightningModule
            The LightningModule instance.
        """
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)

    def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
        """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
        logs = trainer.callback_metrics

        # Early stopping is disabled under ``fast_dev_run``; also short-circuit if the metric is missing.
        if trainer.fast_dev_run or not self._validate_condition_metric(logs):
            return

        current = logs[self.monitor].squeeze()
        should_stop, reason = self._evaluate_stopping_criteria(current)

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
        if reason and self.verbose:
            self._log_info(trainer, reason, self.log_rank_zero_only)

    def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]:
        """
        Decide whether to stop given the current metric value, updating best score / wait count.

        Criteria are checked in order: non-finite value, stopping threshold, divergence
        threshold, improvement (resets ``wait_count``), otherwise patience exhaustion.

        Returns
        -------
        Tuple[bool, Optional[str]]
            Whether to stop, and an optional human-readable reason.
        """
        should_stop = False
        reason = None
        if self.check_finite and not torch.isfinite(current):
            should_stop = True
            reason = (
                f"Monitored metric {self.monitor} = {current} is not finite."
                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
            )
        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
            should_stop = True
            reason = (
                "Stopping threshold reached:"
                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                " Signaling Trainer to stop."
            )
        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
            should_stop = True
            # Divergence means "worse than the threshold", i.e. the opposite of the
            # "better than" symbol stored in ``order_dict`` for this mode.
            worse_than = self.order_dict["max" if self.mode == "min" else "min"]
            reason = (
                "Divergence threshold reached:"
                f" {self.monitor} = {current} {worse_than} {self.divergence_threshold}."
                " Signaling Trainer to stop."
            )
        elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
            should_stop = False
            reason = self._improvement_message(current)
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
                reason = (
                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
                )

        return should_stop, reason

    def _improvement_message(self, current: Tensor) -> str:
        """Formats a log message that informs the user about an improvement in the monitored score."""
        if torch.isfinite(self.best_score):
            msg = (
                f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
                f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
            )
        else:
            # The first improvement is measured against +/-inf, so no delta is reported.
            msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
        return msg

    @staticmethod
    def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
        """Log ``message`` (rank-prefixed in multi-process runs), optionally only on rank 0."""
        rank = trainer.global_rank if trainer.world_size > 1 else None
        message = rank_prefixed_message(message, rank)
        if rank is None or not log_rank_zero_only or rank == 0:
            log.info(message)