import warnings
from typing import Optional

import lightning as L
import optuna
from lightning.pytorch import Trainer
# from lightning.pytorch.trainer.trainer import Trainer
class PyTorchLightningPruningCallback(L.Callback):
    """
    A PyTorch Lightning callback for Optuna pruning.

    At the end of every training epoch the value of ``monitor`` is read from
    ``trainer.callback_metrics`` and reported to the Optuna trial, using the
    module's current epoch as the step. If Optuna then decides the trial
    should be pruned, ``optuna.TrialPruned`` is raised.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The Optuna trial object.
    monitor : str
        The metric to monitor for pruning. It must be logged so that it is
        present in ``trainer.callback_metrics`` when the training epoch ends.

    Attributes
    ----------
    _trial : optuna.trial.Trial
        The Optuna trial object.
    monitor : str
        The metric to monitor for pruning.
    is_ddp_backend : bool
        Whether to use DDP backend. Defaults to `False`. Kept for backward
        compatibility; this callback does not read it.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()
        self._trial = trial
        self.monitor = monitor
        self.is_ddp_backend = False

    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """
        Report the monitored metric to Optuna and prune the trial if requested.

        Called when the training epoch ends.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : L.LightningModule
            The PyTorch Lightning module.

        Raises
        ------
        optuna.exceptions.TrialPruned
            If the trial should be pruned based on the intermediate value.

        Warns
        -----
        UserWarning
            If ``monitor`` is not present in ``trainer.callback_metrics``.
            Nothing is reported for that epoch.
        """
        epoch = pl_module.current_epoch
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            # Trial.report() cannot convert None to a float, so a missing or
            # misspelled metric would otherwise surface as an opaque error.
            # Warn with the available keys and skip this epoch instead.
            warnings.warn(
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name. "
                "Available: {}".format(self.monitor, list(trainer.callback_metrics.keys()))
            )
            return
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(epoch))
class OptunaPruningCallback(L.Callback):
    """
    An Optuna pruning callback for PyTorch Lightning running on validation end.

    After each real (non sanity-check) validation loop, the value of
    ``metric`` is read from ``trainer.callback_metrics``. If that value is a
    dict, ``sub_metric`` is looked up inside it. The result is reported to
    the Optuna trial with the current epoch as the step, and the trial is
    pruned if Optuna says so.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The Optuna trial object.
    metric : str
        The metric to monitor for pruning.
    sub_metric : str, optional
        The sub-metric to monitor for pruning, if `metric` is a dictionary.
        Defaults to `None`.
    start_epoch : int, optional
        The epoch to start pruning. Epochs before it are neither reported
        nor pruned. Defaults to 1.

    Attributes
    ----------
    _trial : optuna.trial.Trial
        The Optuna trial object.
    metric : str
        The metric to monitor for pruning.
    sub_metric : str, optional
        The sub-metric to monitor for pruning.
    start_epoch : int
        The epoch to start pruning.
    """

    def __init__(
        self,
        trial: optuna.trial.Trial,
        metric: str,
        sub_metric: Optional[str] = None,
        start_epoch: int = 1,
    ) -> None:
        super().__init__()
        self._trial = trial
        self.metric = metric
        self.sub_metric = sub_metric
        self.start_epoch = start_epoch

    def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """
        Report the monitored metric to Optuna and prune the trial if requested.

        Called when the validation loop ends.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : L.LightningModule
            The PyTorch Lightning module.

        Raises
        ------
        optuna.exceptions.TrialPruned
            If the trial should be pruned.

        Warns
        -----
        UserWarning
            If the metric (or sub-metric) is not present in the logged
            metrics. Nothing is reported for that epoch.
        """
        # Lightning's sanity check also ends in on_validation_end. Its score
        # would be reported at the current epoch's step, and Optuna keeps only
        # the first value reported per step, so the real validation score
        # would then be ignored.
        if trainer.sanity_checking:
            return
        epoch = pl_module.current_epoch
        if epoch < self.start_epoch:
            return
        current_score = trainer.callback_metrics.get(self.metric)
        if isinstance(current_score, dict):
            current_score = current_score.get(self.sub_metric)
        if current_score is None:
            message = [
                "The metric '{}'-'{}' is not in the evaluation logs for pruning.".format(self.metric, self.sub_metric),
                "Please make sure you set the correct metric name.",
                "Available: {}".format(trainer.callback_metrics.keys()),
            ]
            warnings.warn(" ".join(message))
            return
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)