Source code for gunz_ml.callbacks.optuna

import warnings
from typing import Optional

import lightning as L
import optuna
from lightning.pytorch import Trainer
# from lightning.pytorch.trainer.trainer import Trainer

class PyTorchLightningPruningCallback(L.Callback):
    """
    A PyTorch Lightning callback for Optuna pruning, run at the end of each training epoch.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The Optuna trial object.
    monitor : str
        The key in ``trainer.callback_metrics`` whose value is reported to Optuna.

    Attributes
    ----------
    _trial : optuna.trial.Trial
        The Optuna trial object.
    monitor : str
        The metric to monitor for pruning.
    is_ddp_backend : bool
        Initialized to ``False``. It is not read anywhere in this class and is
        kept only for backward compatibility.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()
        self._trial = trial
        self.monitor = monitor
        # NOTE(review): set but never read by this callback; kept so external
        # code that inspects the attribute keeps working.
        self.is_ddp_backend = False

    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """
        Report the monitored metric to Optuna and prune the trial if requested.

        If ``self.monitor`` is not present in ``trainer.callback_metrics``, a
        warning is emitted and nothing is reported for this epoch.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : L.LightningModule
            The PyTorch Lightning module.

        Raises
        ------
        optuna.exceptions.TrialPruned
            If the trial should be pruned based on the intermediate value.
        """
        epoch = pl_module.current_epoch
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            # Passing None to Trial.report fails with a TypeError. Warn instead, so a
            # misspelled or not-yet-logged metric name is visible without crashing.
            warnings.warn(
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name. "
                "Available: {}".format(self.monitor, list(trainer.callback_metrics.keys()))
            )
            return
        self._trial.report(current_score, epoch)
        if self._trial.should_prune():
            raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(epoch))
class OptunaPruningCallback(L.Callback):
    """
    An Optuna pruning callback for PyTorch Lightning, run at the end of each validation loop.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The Optuna trial object.
    metric : str
        The key in ``trainer.callback_metrics`` to monitor for pruning.
    sub_metric : str, optional
        The entry to select when the value under `metric` is a dictionary.
        Defaults to `None`.
    start_epoch : int, optional
        The first epoch (as given by ``pl_module.current_epoch``) at which values
        are reported and pruning is considered. Defaults to 1.

    Attributes
    ----------
    _trial : optuna.trial.Trial
        The Optuna trial object.
    metric : str
        The metric to monitor for pruning.
    sub_metric : str, optional
        The sub-metric to monitor for pruning.
    start_epoch : int
        The epoch to start pruning.
    """

    def __init__(
        self,
        trial: optuna.trial.Trial,
        metric: str,
        sub_metric: Optional[str] = None,
        start_epoch: int = 1,
    ) -> None:
        super().__init__()
        self._trial = trial
        self.metric = metric
        self.sub_metric = sub_metric
        self.start_epoch = start_epoch

    def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """
        Report the monitored value to Optuna and prune the trial if requested.

        Nothing is reported during the sanity-check run, before `start_epoch`, or
        when the metric (or sub-metric) is missing. In the missing case a warning
        is emitted.

        NOTE(review): if validation runs more than once per epoch, every run
        reports at the same step (the epoch). Optuna keeps only the first value
        reported for a step.

        Parameters
        ----------
        trainer : L.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : L.LightningModule
            The PyTorch Lightning module.

        Raises
        ------
        optuna.exceptions.TrialPruned
            If the trial should be pruned.
        """
        # Lightning also calls on_validation_end after the sanity-check loop. Without
        # this guard, a start_epoch of 0 would report the sanity-check value at step 0.
        # Optuna then ignores the real epoch-0 report for that same step.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch
        if epoch < self.start_epoch:
            return

        current_score = trainer.callback_metrics.get(self.metric)
        if isinstance(current_score, dict):
            # Nested metric: select the requested entry (None if absent or sub_metric unset).
            current_score = current_score.get(self.sub_metric)

        if current_score is None:
            message = [
                "The metric '{}'-'{}' is not in the evaluation logs for pruning.".format(self.metric, self.sub_metric),
                "Please make sure you set the correct metric name.",
                "Available: {}".format(trainer.callback_metrics.keys())
            ]
            warnings.warn(" ".join(message))
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)